From 38f8d153cdf4d41839bede5ac9e6c40ecadb5ceb Mon Sep 17 00:00:00 2001 From: theissenhelen Date: Mon, 24 Jun 2024 14:29:35 +0000 Subject: [PATCH 001/156] feat: Initial implementation of global graphs Co-authored by: Mario Santa Cruz --- .gitignore | 2 + src/anemoi/graphs/__init__.py | 9 +- src/anemoi/graphs/edges/__init__.py | 0 src/anemoi/graphs/edges/attributes.py | 37 +++++++ src/anemoi/graphs/edges/connections.py | 131 +++++++++++++++++++++++ src/anemoi/graphs/edges/directional.py | 83 ++++++++++++++ src/anemoi/graphs/generate.py | 36 +++++++ src/anemoi/graphs/generate/__init__.py | 0 src/anemoi/graphs/generate/transforms.py | 95 ++++++++++++++++ src/anemoi/graphs/nodes/__init__.py | 0 src/anemoi/graphs/nodes/nodes.py | 93 ++++++++++++++++ src/anemoi/graphs/nodes/weights.py | 62 +++++++++++ src/anemoi/graphs/normalizer.py | 22 ++++ src/anemoi/graphs/utils.py | 108 +++++++++++++++++++ 14 files changed, 670 insertions(+), 8 deletions(-) create mode 100644 src/anemoi/graphs/edges/__init__.py create mode 100644 src/anemoi/graphs/edges/attributes.py create mode 100644 src/anemoi/graphs/edges/connections.py create mode 100644 src/anemoi/graphs/edges/directional.py create mode 100644 src/anemoi/graphs/generate.py create mode 100644 src/anemoi/graphs/generate/__init__.py create mode 100644 src/anemoi/graphs/generate/transforms.py create mode 100644 src/anemoi/graphs/nodes/__init__.py create mode 100644 src/anemoi/graphs/nodes/nodes.py create mode 100644 src/anemoi/graphs/nodes/weights.py create mode 100644 src/anemoi/graphs/normalizer.py create mode 100644 src/anemoi/graphs/utils.py diff --git a/.gitignore b/.gitignore index 2137d4c..f20d81f 100644 --- a/.gitignore +++ b/.gitignore @@ -186,3 +186,5 @@ _build/ *.sync _version.py *.code-workspace + +/config* \ No newline at end of file diff --git a/src/anemoi/graphs/__init__.py b/src/anemoi/graphs/__init__.py index eef2c1d..4e2516c 100644 --- a/src/anemoi/graphs/__init__.py +++ b/src/anemoi/graphs/__init__.py @@ -1,9 +1,2 @@ -# (C) Copyright 2024 European Centre for Medium-Range Weather Forecasts. -# This software is licensed under the terms of the Apache Licence Version 2.0 -# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. -# In applying this licence, ECMWF does not waive the privileges and immunities -# granted to it by virtue of its status as an intergovernmental organisation -# nor does it submit to any jurisdiction. +earth_radius = 6371.0 # km - -from ._version import __version__ as __version__ diff --git a/src/anemoi/graphs/edges/__init__.py b/src/anemoi/graphs/edges/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/anemoi/graphs/edges/attributes.py b/src/anemoi/graphs/edges/attributes.py new file mode 100644 index 0000000..b4bcf02 --- /dev/null +++ b/src/anemoi/graphs/edges/attributes.py @@ -0,0 +1,37 @@ +from abc import ABC +from abc import abstractmethod +from typing import Optional + +import numpy as np +from torch_geometric.data import HeteroData + +from anemoi.graphs.edges.directional import directional_edge_features +from anemoi.graphs.normalizer import NormalizerMixin +from anemoi.utils.logger import get_code_logger + +logger = get_code_logger(__name__) + + +class BaseEdgeAttribute(ABC, NormalizerMixin): + norm: Optional[str] = None + + @abstractmethod + def compute(self, graph: HeteroData, *args, **kwargs): ... + + def __call__(self, *args, **kwargs): + values = self.compute(*args, **kwargs) + if values.ndim == 1: + values = values[:, np.newaxis] + return self.normalize(values) + + +class DirectionalFeatures(BaseEdgeAttribute): + norm: Optional[str] = None + luse_rotated_features: bool = False + + def compute(self, graph: HeteroData, src_name: str, dst_name: str): + edge_index = graph[(src_name, "to", dst_name)].edge_index + src_coords = graph[src_name].x.numpy()[edge_index[0]].T + dst_coords = graph[dst_name].x.numpy()[edge_index[1]].T + edge_dirs = directional_edge_features(src_coords, dst_coords, self.luse_rotated_features).T + return edge_dirs diff --git a/src/anemoi/graphs/edges/connections.py b/src/anemoi/graphs/edges/connections.py new file mode 100644 index 0000000..dcface3 --- /dev/null +++ b/src/anemoi/graphs/edges/connections.py @@ -0,0 +1,131 @@ +from abc import abstractmethod +from dataclasses import dataclass +from typing import Optional + +import networkx as nx +import numpy as np +import torch +from anemoi.utils.config import DotDict +from hydra.utils import instantiate +from sklearn.neighbors import NearestNeighbors +from sklearn.preprocessing import normalize +from torch_geometric.data import HeteroData +from torch_geometric.data.storage import NodeStorage + +from anemoi.graphs import earth_radius +from anemoi.graphs.utils import get_grid_reference_distance + +import logging + +logger = logging.getLogger(__name__) + + +class BaseEdgeBuilder: + """Base class for edge builders.""" + + def __init__(self, src_name: str, dst_name: str): + super().__init__() + self.src_name = src_name + self.dst_name = dst_name + + @abstractmethod + def get_adj_matrix(self, src_nodes: NodeStorage, dst_nodes: NodeStorage): ... + + def register_edges(self, graph, head_indices, tail_indices): + graph[(self.src_name, "to", self.dst_name)].edge_index = np.stack([head_indices, tail_indices], axis=0).astype(np.int32) + return graph + + def register_edge_attribute(self, graph: HeteroData, name: str, values: np.ndarray): + num_edges = graph[(self.src_name, "to", self.dst_name)].num_edges + assert ( + values.shape[0] == num_edges + ), f"Number of edge features ({values.shape[0]}) must match number of edges ({num_edges})." + graph[self.src_name, "to", self.dst_name][name] = values.reshape(num_edges, -1) # TODO: Check the [name] part works + return graph + + def prepare_node_data(self, graph: HeteroData): + return graph[self.src_name], graph[self.dst_name] + + def transform(self, graph: HeteroData, attrs_config: Optional[DotDict] = None) -> HeteroData: + # Get source and destination nodes. + src_nodes, dst_nodes = self.prepare_node_data(graph) + + # Compute adjacency matrix. + adjmat = self.get_adj_matrix(src_nodes, dst_nodes) + + # Normalize adjacency matrix. + adjmat_norm = self.normalize_adjmat(adjmat) + + # Add edges to the graph and register normed distance. + graph = self.register_edges(graph, adjmat.col, adjmat.row) + + self.register_edge_attribute(graph, "normed_dist", adjmat_norm.data) + if attrs_config is not None: + for attr_name, attr_cfg in attrs_config.items(): + attr_values = instantiate(attr_cfg)(graph, self.src_name, self.dst_name) + graph = self.register_edge_attribute(graph, attr_name, attr_values) + + return graph + + def normalize_adjmat(self, adjmat): + """Normalize a sparse adjacency matrix.""" + adjmat_norm = normalize(adjmat, norm="l1", axis=1) + adjmat_norm.data = 1.0 - adjmat_norm.data + return adjmat_norm + + +class KNNEdgeBuilder(BaseEdgeBuilder): + """Computes KNN based edges and adds them to the graph.""" + + def __init__(self, src_name: str, dst_name: str, num_nearest_neighbours: int): + super().__init__(src_name, dst_name) + assert isinstance(num_nearest_neighbours, int), "Number of nearest neighbours must be an integer" + assert num_nearest_neighbours > 0, "Number of nearest neighbours must be positive" + self.num_nearest_neighbours = num_nearest_neighbours + + def get_adj_matrix(self, src_nodes: np.ndarray, dst_nodes: np.ndarray): + assert self.num_nearest_neighbours is not None, "number of neighbors required for knn encoder" + logger.debug( + "Using %d nearest neighbours for KNN-Edges between %s and %s.", + self.num_nearest_neighbours, + self.src_name, + self.dst_name, + ) + + nearest_neighbour = NearestNeighbors(metric="haversine", n_jobs=4) + nearest_neighbour.fit(src_nodes.x.numpy()) + adj_matrix = nearest_neighbour.kneighbors_graph( + dst_nodes.x.numpy(), + n_neighbors=self.num_nearest_neighbours, + mode="distance", + ).tocoo() + return adj_matrix + + +class CutOffEdgeBuilder(BaseEdgeBuilder): + """Computes cut-off based edges and adds them to the graph.""" + + def __init__(self, src_name: str, dst_name: str, cutoff_factor: float): + super().__init__(src_name, dst_name) + assert isinstance(cutoff_factor, float), "Cutoff factor must be a float" + assert cutoff_factor > 0, "Cutoff factor must be positive" + self.cutoff_factor = cutoff_factor + + def get_cutoff_radius(self, dst_nodes: NodeStorage, mask_attr: Optional[torch.Tensor] = None): + mask = dst_nodes[mask_attr] if mask_attr is not None else None + dst_grid_reference_distance = get_grid_reference_distance(dst_nodes.x, mask) + radius = dst_grid_reference_distance * self.cutoff_factor + return radius + + def prepare_node_data(self, graph: HeteroData): + self.radius = self.get_cutoff_radius(graph) + return super().prepare_node_data(graph) + + def get_adj_matrix(self, src_nodes: NodeStorage, dst_nodes: NodeStorage): + logger.debug("Using cut-off radius of %.1f km.", self.radius * earth_radius) + + nearest_neighbour = NearestNeighbors(metric="haversine", n_jobs=4) + nearest_neighbour.fit(src_nodes.x) + adj_matrix = nearest_neighbour.radius_neighbors_graph(dst_nodes.x, radius=self.radius).tocoo() + return adj_matrix + diff --git a/src/anemoi/graphs/edges/directional.py b/src/anemoi/graphs/edges/directional.py new file mode 100644 index 0000000..aac7c79 --- /dev/null +++ b/src/anemoi/graphs/edges/directional.py @@ -0,0 +1,83 @@ +from typing import Optional + +import numpy as np +from scipy.spatial.transform import Rotation + +from anemoi.graphs.generate.transforms import direction_vec +from anemoi.graphs.generate.transforms import to_sphere_xyz + + +def get_rotation_from_unit_vecs(points: np.ndarray, reference: np.ndarray) -> Rotation: + """Compute rotation matrix of a set of points with respect to a reference vector. + + Parameters + ---------- + points : np.ndarray of shape (num_points, 3) + The points to compute the direction vector. + reference : np.ndarray of shape (3, ) + The reference vector. + + Returns + ------- + Rotation + The rotation matrix that aligns the points with the reference vector. + """ + assert points.shape[1] == 3, "Points must be in 3D" + v_unit = direction_vec(points, reference) + theta = np.arccos(np.dot(points, reference)) + return Rotation.from_rotvec(np.transpose(v_unit * theta)) + + +def compute_directions(loc1: np.ndarray, loc2: np.ndarray, pole_vec: Optional[np.ndarray] = None) -> np.ndarray: + """Compute the direction of the edge joining the nodes considered. + + Parameters + ---------- + loc1 : np.ndarray of shape (2, num_points) + Location of the head nodes. + loc2 : np.ndarray + Location of the tail nodes. + pole_vec : np.ndarray, optional + The pole vector to rotate the points to. Defaults to the north pole. + + Returns + ------- + np.ndarray of shape (3, num_points) + The direction of the edge after rotating the north pole. + """ + if pole_vec is None: + pole_vec = np.array([0, 0, 1]) + + # all will be rotated relative to destination node + loc1_xyz = to_sphere_xyz(loc1, 1.0) + loc2_xyz = to_sphere_xyz(loc2, 1.0) + r = get_rotation_from_unit_vecs(loc2_xyz, pole_vec) + direction = direction_vec(r.apply(loc1_xyz), pole_vec) + return direction / np.sqrt(np.power(direction, 2).sum(axis=0)) + + +def directional_edge_features(loc1: np.ndarray, loc2: np.ndarray, relative_to_rotated_target: bool = True) -> np.ndarray: + """Compute features of the edge joining the nodes considered. + + It computes the direction of the edge after rotating the north pole. + + Parameters + ---------- + loc1 : np.ndarray of shpae (2, num_points) + Location of the head node. + loc2 : np.ndarray of shape (2, num_points) + Location of the tail node. + relative_to_rotated_target : bool, optional + Whether to rotate the north pole to the target node. Defaults to True. + + Returns + ------- + np.ndarray of shape of (2, num_points) + Direction of the edge after rotation the north pole. + """ + if relative_to_rotated_target: + rotation = compute_directions(loc1, loc2) + assert np.allclose(rotation[2], 0), "Rotation should be aligned with the north pole" + return rotation[:2] + + return loc2 - loc1 diff --git a/src/anemoi/graphs/generate.py b/src/anemoi/graphs/generate.py new file mode 100644 index 0000000..e44529e --- /dev/null +++ b/src/anemoi/graphs/generate.py @@ -0,0 +1,36 @@ +from abc import ABC +from abc import abstractmethod + +import hydra +from anemoi.utils.config import DotDict +from hydra.utils import instantiate +from omegaconf import DictConfig +from torch_geometric.data import HeteroData + +import logging + +logger = logging.getLogger(__name__) + + +def generate_graph(graph_config): + graph = HeteroData() + + for name, nodes_cfg in graph_config.nodes.items(): + graph = instantiate(nodes_cfg.node_type).transform(graph, name, nodes_cfg.get("attributes", {})) + + for edges_cfg in graph_config.edges: + graph = instantiate(edges_cfg.edge_type, **edges_cfg.nodes).transform(graph, edges_cfg.get("attributes", {})) + + return graph + + +@hydra.main(version_base=None, config_path="../config", config_name="config") +def main(config: DictConfig): + + graph = generate_graph(config) + + return graph + + +if __name__ == "__main__": + main() diff --git a/src/anemoi/graphs/generate/__init__.py b/src/anemoi/graphs/generate/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/anemoi/graphs/generate/transforms.py b/src/anemoi/graphs/generate/transforms.py new file mode 100644 index 0000000..5204484 --- /dev/null +++ b/src/anemoi/graphs/generate/transforms.py @@ -0,0 +1,95 @@ +import numpy as np + + +def cartesian_to_latlon_degrees(xyz: np.ndarray) -> np.ndarray: + """3D to lat-lon conversion. + + Convert 3D coordinates of points to the (lat, lon) on the sphere containing + them. + + Parameters + ---------- + xyz : np.ndarray + The 3D coordinates of points. + + Returns + ------- + np.ndarray + A 2D array of lat-lon coordinates of shape (N, 2). + """ + lat = np.arcsin(xyz[..., 2] / (xyz**2).sum(axis=1)) * 180.0 / np.pi + lon = np.arctan2(xyz[..., 1], xyz[..., 0]) * 180.0 / np.pi + return np.array((lat, lon), dtype=np.float32).transpose() + + +def cartesian_to_latlon_rad(xyz: np.ndarray) -> np.ndarray: + """Degrees to radians conversion. + + Convert 3D coordinates of points to its coordinates on the sphere containing + them. + + Parameters + ---------- + xyz : np.ndarray + The 3D coordinates of points. + + Returns + ------- + np.ndarray + A 2D array of the coordinates of shape (N, 2) in radians. + """ + lat = np.arcsin(xyz[..., 2] / (xyz**2).sum(axis=1)) + lon = np.arctan2(xyz[..., 1], xyz[..., 0]) + return np.array((lat, lon), dtype=np.float32).transpose() + + +def to_sphere_xyz(loc: tuple[np.ndarray, np.ndarray], radius: float = 1) -> np.ndarray: + """Convert planar coordinates to 3D coordinates in a sphere. + + Parameters + ---------- + loc : np.ndarray + The 2D coordinates of the points, in radians. + radius : float, optional + The radius of the sphere containing los points. Defaults to the unit sphere. + + Returns + ------- + np.array of shape (3, num_points) + 3D coordinates of the points in the sphere. + """ + latr, lonr = loc[0], loc[1] + x = radius * np.cos(latr) * np.cos(lonr) + y = radius * np.cos(latr) * np.sin(lonr) + z = radius * np.sin(latr) + return np.array((x, y, z)).T + + +def direction_vec(points: np.ndarray, reference: np.ndarray, epsilon: float = 10e-11) -> np.ndarray: + """Direction vector computation. + + Compute the direction vector of a set of points with respect to a reference + vector. + + Parameters + ---------- + points : np.array of shape (num_points, 3) + The points to compute the direction vector. + reference : np.array of shape (3, ) + The reference vector. + epsilon : float, optional + The value to add to the first vector to avoid division by zero. Defaults to 10e-11. + + Returns + ------- + np.array of shape (3, num_points) + The direction vector of the cross product of the two vectors. + """ + v = np.cross(points, reference) + vnorm1 = np.power(v, 2).sum(axis=-1) + redo_idx = np.where(vnorm1 < epsilon)[0] + if len(redo_idx) > 0: + points[redo_idx] += epsilon + v = np.cross(points, reference) + vnorm1 = np.power(v, 2).sum(axis=-1) + return v.T / np.sqrt(vnorm1) diff --git a/src/anemoi/graphs/nodes/__init__.py b/src/anemoi/graphs/nodes/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/anemoi/graphs/nodes/nodes.py b/src/anemoi/graphs/nodes/nodes.py new file mode 100644 index 0000000..d00dd62 --- /dev/null +++ b/src/anemoi/graphs/nodes/nodes.py @@ -0,0 +1,93 @@ +from abc import abstractmethod +from pathlib import Path +from typing import Optional +from typing import Union + +import h3 +import numpy as np +import torch +from abc import ABC +from anemoi.datasets import open_dataset +from anemoi.utils.config import DotDict +from hydra.utils import instantiate +from sklearn.neighbors import NearestNeighbors +from torch_geometric.data import HeteroData + +from aifs.graphs import GraphBuilder +from aifs.graphs.generate.hexagonal import create_hexagonal_nodes +from aifs.graphs.generate.icosahedral import create_icosahedral_nodes +import logging + +logger = logging.getLogger(__name__) +earth_radius = 6371.0 # km + + +def latlon_to_radians(coords: np.ndarray) -> np.ndarray: + return np.deg2rad(coords) + + +def rad_to_latlon(coords: np.ndarray) -> np.ndarray: + """Converts coordinates from radians to degrees. + + Parameters + ---------- + coords : np.ndarray + Coordinates in radians. + + Returns + ------- + np.ndarray + _description_ + """ + return np.rad2deg(coords) + + +class BaseNodeBuilder(ABC): + + def register_nodes(self, graph: HeteroData, name: str) -> None: + graph[name].x = self.get_coordinates() + graph[name].node_type = type(self).__name__ + return graph + + def register_attributes(self, graph: HeteroData, name: str, config: DotDict) -> HeteroData: + for nodes_attr_name, attr_cfg in config.items(): + graph[name][nodes_attr_name] = instantiate(attr_cfg).get_weights(graph[name]) + return graph + + @abstractmethod + def get_coordinates(self) -> np.ndarray: ... + + def reshape_coords(self, latitudes: np.ndarray, longitudes: np.ndarray) -> np.ndarray: + coords = np.stack([latitudes, longitudes], axis=-1).reshape((-1, 2)) + coords = np.deg2rad(coords) + # TODO: type needs to be variable? + return torch.tensor(coords, dtype=torch.float32) + + def transform(self, graph: HeteroData, name: str, attr_config: DotDict) -> HeteroData: + graph = self.register_nodes(graph, name) + graph = self.register_attributes(graph, name, attr_config) + return graph + + +class ZarrNodes(BaseNodeBuilder): + """Nodes from Zarr dataset.""" + + def __init__(self, dataset: DotDict) -> None: + logger.info("Reading the dataset from %s.", dataset) + self.ds = open_dataset(dataset) + + def get_coordinates(self) -> torch.Tensor: + return self.reshape_coords(self.ds.latitudes, self.ds.longitudes) + + +class NPZNodes(BaseNodeBuilder): + """Nodes from NPZ defined grids.""" + + def __init__(self, resolution: str, grid_definition_path: str) -> None: + self.resolution = resolution + self.grid_definition_path = grid_definition_path + self.grid_definition = np.load(Path(self.grid_definition_path) / f"grid-{self.resolution}.npz") + + def get_coordinates(self) -> np.ndarray: + coords = self.reshape_coords(self.grid_definition["latitudes"], self.grid_definition["longitudes"]) + return coords diff --git a/src/anemoi/graphs/nodes/weights.py b/src/anemoi/graphs/nodes/weights.py new file mode 100644 index 0000000..e2249f1 --- /dev/null +++ b/src/anemoi/graphs/nodes/weights.py @@ -0,0 +1,62 @@ +from abc import ABC +from abc import abstractmethod +from typing import Optional + +import numpy as np +import torch +from torch_geometric.data.storage import NodeStorage + +from anemoi.graphs.generate.transforms import to_sphere_xyz +from scipy.spatial import SphericalVoronoi +from anemoi.graphs.normalizer import NormalizerMixin +import logging + +logger = logging.getLogger(__name__) + +class BaseWeights(ABC, NormalizerMixin): + """Base class for the weights of the nodes.""" + + def __init__(self, norm: Optional[str] = None): + self.norm = norm + + @abstractmethod + def compute(self, nodes: NodeStorage, *args, **kwargs): ... + + def get_weights(self, *args, **kwargs): + weights = self.compute(*args, **kwargs) + if weights.ndim == 1: + weights = weights[:, np.newaxis] + return self.normalize(weights) + + +class UniformWeights(BaseWeights): + """Implements a uniform weight for the nodes.""" + + def __init__(self, norm: str = "unit-max"): + self.norm = norm + + def compute(self, nodes: NodeStorage) -> np.ndarray: + return torch.ones(nodes.num_nodes) + + +class AreaWeights(BaseWeights): + """Implements the area of the nodes as the weights.""" + + def __init__(self, norm: str = "unit-max", radius: float = 1.0, centre: np.ndarray = np.array[0, 0, 0]): + # Weighting of the nodes + self.norm: str = norm + self.radius: float = radius + self.centre: np.ndarray = centre + + def compute(self, nodes: NodeStorage, *args, **kwargs) -> np.ndarray: + # TODO: Check if works + latitudes, longitudes = nodes.x[:, 0], nodes.x[:, 1] + points = to_sphere_xyz((latitudes, longitudes)) + sv = SphericalVoronoi(points, self.radius, self.centre) + area_weights = sv.calculate_areas() + logger.debug( + "There are %d of weights, which (unscaled) add up a total weight of %.2f.", + len(area_weights), + np.array(area_weights).sum(), + ) + return area_weights diff --git a/src/anemoi/graphs/normalizer.py b/src/anemoi/graphs/normalizer.py new file mode 100644 index 0000000..3a6bce6 --- /dev/null +++ b/src/anemoi/graphs/normalizer.py @@ -0,0 +1,22 @@ +import numpy as np +import logging + +logger = logging.getLogger(__name__) + + +class NormalizerMixin: + def normalize(self, values: np.ndarray) -> np.ndarray: + if self.norm is None: + logger.debug("Node weights are not normalized.") + return values + if self.norm == "l1": + return values / np.sum(values) + if self.norm == "l2": + return values / np.linalg.norm(values) + if self.norm == "unit-max": + return values / np.amax(values) + if self.norm == "unit-sum": + return values / np.sum(values) + if self.norm == "unit-std": + return values / np.std(values) + raise ValueError("Weight normalization must be 'l1', 'l2', 'unit-max' 'unit-sum' or 'unit-std'.") diff --git a/src/anemoi/graphs/utils.py b/src/anemoi/graphs/utils.py new file mode 100644 index 0000000..9eed60a --- /dev/null +++ b/src/anemoi/graphs/utils.py @@ -0,0 +1,108 @@ +from typing import Optional + +import numpy as np +import torch +from sklearn.neighbors import NearestNeighbors + + +def get_nearest_neighbour(coords_rad: torch.Tensor, mask: Optional[torch.Tensor] = None) -> NearestNeighbors: + """Get NearestNeighbour object fitted to coordinates. + + Parameters + ---------- + coords_rad : torch.Tensor + corrdinates in radians + mask : Optional[torch.Tensor], optional + mask to remove nodes, by default None + + Returns + ------- + NearestNeighbors + fitted NearestNeighbour object + """ + assert mask is None or mask.shape == (coords_rad.shape[0], 1), "Mask must have the same shape as the number of nodes." + + nearest_neighbour = NearestNeighbors(metric="haversine", n_jobs=4) + + nearest_neighbour.fit(coords_rad) + + return nearest_neighbour + + +def get_grid_reference_distance(coords_rad: torch.Tensor, mask: Optional[torch.Tensor] = None) -> float: + """Get the reference distance of the grid. + + It is the maximum distance of a node in the mesh with respect to its nearest neighbour. + + Parameters + ---------- + coords_rad : torch.Tensor + corrdinates in radians + mask : Optional[torch.Tensor], optional + mask to remove nodes, by default None + + Returns + ------- + float + The reference distance of the grid. + """ + nearest_neighbours = get_nearest_neighbour(coords_rad, mask) + dists, _ = nearest_neighbours.kneighbors(coords_rad, n_neighbors=2, return_distance=True) + return dists[dists > 0].max() + + +def add_margin(lats: np.ndarray, lons: np.ndarray, margin: float) -> tuple[np.ndarray, np.ndarray]: + """Add a margin to the convex hull of the points considered. + + For each point (lat, lon) add 8 points around it, each at a distance of `margin` from the original point. + + Arguments + --------- + lats : np.ndarray + Latitudes of the points considered. + lons : np.ndarray + Longitudes of the points considered. + margin : float + The margin to add to the convex hull. + + Returns + ------- + latitudes : np.ndarray + Latitudes of the points considered, including the margin. + longitudes : np.ndarray + Longitudes of the points considered, including the margin. + """ + assert margin >= 0, "Margin must be non-negative" + if margin == 0: + return lats, lons + + latitudes, longitudes = [], [] + for lat_sign in [-1, 0, 1]: + for lon_sign in [-1, 0, 1]: + latitudes.append(lats + lat_sign * margin) + longitudes.append(lons + lon_sign * margin) + + return np.concatenate(latitudes), np.concatenate(longitudes) + + +def get_index_in_outer_join(vector: torch.Tensor, tensor: torch.Tensor) -> int: + """Index position of vector. + + Get the index position of a vector in a matrix. + + Parameters + ---------- + vector : torch.Tensor of shape (N, ) + Vector to get its position in the matrix. + tensor : torch.Tensor of shape (M, N,) + Tensor in which the position is searched. + + Returns + ------- + int + Index position of the tensor in the other tensor. -1 if tensor1 is not in tensor2 + """ + mask = torch.all(tensor == vector, axis=1) + if mask.any(): + return int(torch.where(mask)[0]) + return -1 From 9dc2cec7dddb10daadbc1ddf2e7402fc3a295c01 Mon Sep 17 00:00:00 2001 From: theissenhelen Date: Tue, 25 Jun 2024 08:27:21 +0000 Subject: [PATCH 002/156] add dependencies --- pyproject.toml | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 57fc0ff..e2a0ac1 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -50,7 +50,10 @@ dynamic = [ "version", ] dependencies = [ - "anemoi-datasets", + "anemoi-datasets[data]>=0.3.3", + "torch>=2.2", + "torch-geometric>=2.3.1,<2.5", + "anemoi-utils>=0.1.3", ] optional-dependencies.all = [ From f1fe18fe3db503eef593762278e1ad3ba728f315 Mon Sep 17 00:00:00 2001 From: Mario Santa Cruz Date: Tue, 25 Jun 2024 11:58:42 +0000 Subject: [PATCH 003/156] add cli command --- .gitignore | 2 +- README.md | 8 +++ pyproject.toml | 2 +- src/anemoi/graphs/__init__.py | 10 +++- src/anemoi/graphs/commands/create.py | 28 +++++++++++ src/anemoi/graphs/commands/hello.py | 32 ------------ src/anemoi/graphs/create.py | 68 ++++++++++++++++++++++++++ src/anemoi/graphs/edges/attributes.py | 4 +- src/anemoi/graphs/edges/connections.py | 17 ++++--- src/anemoi/graphs/edges/directional.py | 4 +- src/anemoi/graphs/generate.py | 36 -------------- src/anemoi/graphs/nodes/nodes.py | 12 +---- src/anemoi/graphs/nodes/weights.py | 18 +++---- src/anemoi/graphs/normalizer.py | 3 +- src/anemoi/graphs/utils.py | 5 +- 15 files changed, 145 insertions(+), 104 deletions(-) create mode 100644 src/anemoi/graphs/commands/create.py delete mode 100644 src/anemoi/graphs/commands/hello.py create mode 100644 src/anemoi/graphs/create.py delete mode 100644 src/anemoi/graphs/generate.py diff --git a/.gitignore b/.gitignore index f20d81f..05d7042 100644 --- a/.gitignore +++ b/.gitignore @@ -187,4 +187,4 @@ _build/ _version.py *.code-workspace -/config* \ No newline at end of file +/config* diff --git a/README.md b/README.md index f55e7da..a4fc900 100644 --- a/README.md +++ b/README.md @@ -19,6 +19,14 @@ Install via `pip` with: $ pip install anemoi-graphs ``` +## Usage + +Create you graph + +``` +$ anemoi-graphs create recipe.yaml my_graph.pt +``` + ## License ``` diff --git a/pyproject.toml b/pyproject.toml index e2a0ac1..74c3e0c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -51,9 +51,9 @@ dynamic = [ ] dependencies = [ "anemoi-datasets[data]>=0.3.3", + "anemoi-utils>=0.3.6", "torch>=2.2", "torch-geometric>=2.3.1,<2.5", - "anemoi-utils>=0.1.3", ] optional-dependencies.all = [ diff --git a/src/anemoi/graphs/__init__.py b/src/anemoi/graphs/__init__.py index 4e2516c..80d19fb 100644 --- a/src/anemoi/graphs/__init__.py +++ b/src/anemoi/graphs/__init__.py @@ -1,2 +1,10 @@ -earth_radius = 6371.0 # km +# (C) Copyright 2023 European Centre for Medium-Range Weather Forecasts. +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. + +from ._version import __version__ +earth_radius = 6371.0 # km diff --git a/src/anemoi/graphs/commands/create.py b/src/anemoi/graphs/commands/create.py new file mode 100644 index 0000000..18b3127 --- /dev/null +++ b/src/anemoi/graphs/commands/create.py @@ -0,0 +1,28 @@ +from anemoi.graphs.create import GraphCreator + +from . import Command + + +class Create(Command): + """Create a graph.""" + + internal = True + timestamp = True + + def add_arguments(self, command_parser): + command_parser.add_argument( + "--overwrite", + action="store_true", + help="Overwrite existing files. This will delete the target graph if it already exists.", + ) + command_parser.add_argument("config", help="Configuration yaml file defining the recipe to create the graph.") + command_parser.add_argument("path", help="Path to store the created graph.") + + def run(self, args): + kwargs = vars(args) + + c = GraphCreator(**kwargs) + c.create() + + +command = Create diff --git a/src/anemoi/graphs/commands/hello.py b/src/anemoi/graphs/commands/hello.py deleted file mode 100644 index 12a0495..0000000 --- a/src/anemoi/graphs/commands/hello.py +++ /dev/null @@ -1,32 +0,0 @@ -#!/usr/bin/env python -# (C) Copyright 2024 ECMWF. -# -# This software is licensed under the terms of the Apache Licence Version 2.0 -# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. -# In applying this licence, ECMWF does not waive the privileges and immunities -# granted to it by virtue of its status as an intergovernmental organisation -# nor does it submit to any jurisdiction. -# - -"""Command place holder. Delete when we have real commands. - -""" - -from . import Command - - -def say_hello(greetings, who): - print(greetings, who) - - -class Hello(Command): - - def add_arguments(self, command_parser): - command_parser.add_argument("--greetings", default="hello") - command_parser.add_argument("--who", default="world") - - def run(self, args): - say_hello(args.greetings, args.who) - - -command = Hello diff --git a/src/anemoi/graphs/create.py b/src/anemoi/graphs/create.py new file mode 100644 index 0000000..cbbe879 --- /dev/null +++ b/src/anemoi/graphs/create.py @@ -0,0 +1,68 @@ +import os +import torch +from anemoi.utils.config import DotDict +from hydra.utils import instantiate +from torch_geometric.data import HeteroData + +import logging + +logger = logging.getLogger(__name__) + + +def generate_graph(graph_config: DotDict) -> HeteroData: + graph = HeteroData() + + for name, nodes_cfg in graph_config.nodes.items(): + graph = instantiate(nodes_cfg.node_type).transform(graph, name, nodes_cfg.get("attributes", {})) + + for edges_cfg in graph_config.edges: + graph = instantiate(edges_cfg.edge_type, **edges_cfg.nodes).transform(graph, edges_cfg.get("attributes", {})) + + return graph + + +class GraphCreator: + def __init__( + self, + path, + config=None, + cache=None, + print=print, + overwrite=False, + **kwargs, + ): + self.path = path # Output path + self.config = config + self.cache = cache + self.print = print + self.overwrite = overwrite + + def init(self): + assert os.path.exists(self.config), f"Path {self.config} does not exist." + + if self._path_readable() and not self.overwrite: + raise Exception(f"{self.path} already exists. Use overwrite=True to overwrite.") + + def load(self) -> HeteroData: + config = DotDict.from_file(self.config) + graph = generate_graph(config) + return graph + + def save(self, graph: HeteroData) -> None: + if not os.path.exists(self.path) or self.overwrite: + torch.save(graph, self.path) + self.print(f"Graph saved at {self.path}.") + + def create(self): + self.init() + graph = self.load() + self.save(graph) + + def _path_readable(self) -> bool: + import torch + + try: + torch.load(self.path, "r") + return True + except FileNotFoundError: + return False diff --git a/src/anemoi/graphs/edges/attributes.py b/src/anemoi/graphs/edges/attributes.py index b4bcf02..15f6caf 100644 --- a/src/anemoi/graphs/edges/attributes.py +++ b/src/anemoi/graphs/edges/attributes.py @@ -2,14 +2,14 @@ from abc import abstractmethod from typing import Optional +import logging import numpy as np from torch_geometric.data import HeteroData from anemoi.graphs.edges.directional import directional_edge_features from anemoi.graphs.normalizer import NormalizerMixin -from anemoi.utils.logger import get_code_logger -logger = get_code_logger(__name__) +logger = logging.getLogger(__name__) class BaseEdgeAttribute(ABC, NormalizerMixin): diff --git a/src/anemoi/graphs/edges/connections.py b/src/anemoi/graphs/edges/connections.py index dcface3..6bf057e 100644 --- a/src/anemoi/graphs/edges/connections.py +++ b/src/anemoi/graphs/edges/connections.py @@ -1,8 +1,7 @@ +import logging from abc import abstractmethod -from dataclasses import dataclass from typing import Optional -import networkx as nx import numpy as np import torch from anemoi.utils.config import DotDict @@ -15,8 +14,6 @@ from anemoi.graphs import earth_radius from anemoi.graphs.utils import get_grid_reference_distance -import logging - logger = logging.getLogger(__name__) @@ -32,7 +29,9 @@ def __init__(self, src_name: str, dst_name: str): def get_adj_matrix(self, src_nodes: NodeStorage, dst_nodes: NodeStorage): ... def register_edges(self, graph, head_indices, tail_indices): - graph[(self.src_name, "to", self.dst_name)].edge_index = np.stack([head_indices, tail_indices], axis=0).astype(np.int32) + graph[(self.src_name, "to", self.dst_name)].edge_index = np.stack([head_indices, tail_indices], axis=0).astype( + np.int32 + ) return graph def register_edge_attribute(self, graph: HeteroData, name: str, values: np.ndarray): @@ -40,7 +39,9 @@ def register_edge_attribute(self, graph: HeteroData, name: str, values: np.ndarr assert ( values.shape[0] == num_edges ), f"Number of edge features ({values.shape[0]}) must match number of edges ({num_edges})." - graph[self.src_name, "to", self.dst_name][name] = values.reshape(num_edges, -1) # TODO: Check the [name] part works + graph[self.src_name, "to", self.dst_name][name] = values.reshape( + num_edges, -1 + ) # TODO: Check the [name] part works return graph def prepare_node_data(self, graph: HeteroData): @@ -111,7 +112,8 @@ def __init__(self, src_name: str, dst_name: str, cutoff_factor: float): assert cutoff_factor > 0, "Cutoff factor must be positive" self.cutoff_factor = cutoff_factor - def get_cutoff_radius(self, dst_nodes: NodeStorage, mask_attr: Optional[torch.Tensor] = None): + def get_cutoff_radius(self, graph: HeteroData, mask_attr: Optional[torch.Tensor] = None): + dst_nodes = graph[self.dst_name] mask = dst_nodes[mask_attr] if mask_attr is not None else None dst_grid_reference_distance = get_grid_reference_distance(dst_nodes.x, mask) radius = dst_grid_reference_distance * self.cutoff_factor @@ -128,4 +130,3 @@ def get_adj_matrix(self, src_nodes: NodeStorage, dst_nodes: NodeStorage): nearest_neighbour.fit(src_nodes.x) adj_matrix = nearest_neighbour.radius_neighbors_graph(dst_nodes.x, radius=self.radius).tocoo() return adj_matrix - diff --git a/src/anemoi/graphs/edges/directional.py b/src/anemoi/graphs/edges/directional.py index aac7c79..9ac3a61 100644 --- a/src/anemoi/graphs/edges/directional.py +++ b/src/anemoi/graphs/edges/directional.py @@ -56,7 +56,9 @@ def compute_directions(loc1: np.ndarray, loc2: np.ndarray, pole_vec: Optional[np return direction / np.sqrt(np.power(direction, 2).sum(axis=0)) -def directional_edge_features(loc1: np.ndarray, loc2: np.ndarray, relative_to_rotated_target: bool = True) -> np.ndarray: +def directional_edge_features( + loc1: np.ndarray, loc2: np.ndarray, relative_to_rotated_target: bool = True +) -> np.ndarray: """Compute features of the edge joining the nodes considered. It computes the direction of the edge after rotating the north pole. diff --git a/src/anemoi/graphs/generate.py b/src/anemoi/graphs/generate.py deleted file mode 100644 index e44529e..0000000 --- a/src/anemoi/graphs/generate.py +++ /dev/null @@ -1,36 +0,0 @@ -from abc import ABC -from abc import abstractmethod - -import hydra -from anemoi.utils.config import DotDict -from hydra.utils import instantiate -from omegaconf import DictConfig -from torch_geometric.data import HeteroData - -import logging - -logger = logging.getLogger(__name__) - - -def generate_graph(graph_config): - graph = HeteroData() - - for name, nodes_cfg in graph_config.nodes.items(): - graph = instantiate(nodes_cfg.node_type).transform(graph, name, nodes_cfg.get("attributes", {})) - - for edges_cfg in graph_config.edges: - graph = instantiate(edges_cfg.edge_type, **edges_cfg.nodes).transform(graph, edges_cfg.get("attributes", {})) - - return graph - - -@hydra.main(version_base=None, config_path="../config", config_name="config") -def main(config: DictConfig): - - graph = generate_graph(config) - - return graph - - -if __name__ == "__main__": - main() diff --git a/src/anemoi/graphs/nodes/nodes.py b/src/anemoi/graphs/nodes/nodes.py index d00dd62..8de951f 100644 --- a/src/anemoi/graphs/nodes/nodes.py +++ b/src/anemoi/graphs/nodes/nodes.py @@ -1,23 +1,15 @@ +import logging +from abc import ABC from abc import abstractmethod from pathlib import Path -from typing import Optional -from typing import Union -import h3 import numpy as np import torch -from abc import ABC from anemoi.datasets import open_dataset from anemoi.utils.config import DotDict from hydra.utils import instantiate -from sklearn.neighbors import NearestNeighbors from torch_geometric.data import HeteroData -from aifs.graphs import GraphBuilder -from aifs.graphs.generate.hexagonal import create_hexagonal_nodes -from aifs.graphs.generate.icosahedral import create_icosahedral_nodes -import logging - logger = logging.getLogger(__name__) earth_radius = 6371.0 # km diff --git a/src/anemoi/graphs/nodes/weights.py b/src/anemoi/graphs/nodes/weights.py index e2249f1..3afe523 100644 --- a/src/anemoi/graphs/nodes/weights.py +++ b/src/anemoi/graphs/nodes/weights.py @@ -1,18 +1,19 @@ +import logging from abc import ABC from abc import abstractmethod from typing import Optional import numpy as np import torch +from scipy.spatial import SphericalVoronoi from torch_geometric.data.storage import NodeStorage from anemoi.graphs.generate.transforms import to_sphere_xyz -from scipy.spatial import SphericalVoronoi from anemoi.graphs.normalizer import NormalizerMixin -import logging logger = logging.getLogger(__name__) + class BaseWeights(ABC, NormalizerMixin): """Base class for the weights of the nodes.""" @@ -32,9 +33,6 @@ def get_weights(self, *args, **kwargs): class UniformWeights(BaseWeights): """Implements a uniform weight for the nodes.""" - def __init__(self, norm: str = "unit-max"): - self.norm = norm - def compute(self, nodes: NodeStorage) -> np.ndarray: return torch.ones(nodes.num_nodes) @@ -42,14 +40,14 @@ def compute(self, nodes: NodeStorage) -> np.ndarray: class AreaWeights(BaseWeights): """Implements the area of the nodes as the weights.""" - def __init__(self, norm: str = "unit-max", radius: float = 1.0, centre: np.ndarray = np.array[0, 0, 0]): + def __init__(self, norm: str = "unit-max", radius: float = 1.0, centre: np.ndarray = np.array([0, 0, 0])): + super().__init__(norm=norm) + # Weighting of the nodes - self.norm: str = norm - self.radius: float = radius - self.centre: np.ndarray = centre + self.radius = radius + self.centre = centre def compute(self, nodes: NodeStorage, *args, **kwargs) -> np.ndarray: - # TODO: Check if works latitudes, longitudes = nodes.x[:, 0], nodes.x[:, 1] points = to_sphere_xyz((latitudes, longitudes)) sv = SphericalVoronoi(points, self.radius, self.centre) diff --git a/src/anemoi/graphs/normalizer.py b/src/anemoi/graphs/normalizer.py index 3a6bce6..593de06 100644 --- a/src/anemoi/graphs/normalizer.py +++ b/src/anemoi/graphs/normalizer.py @@ -1,6 +1,7 @@ -import numpy as np import logging +import numpy as np + logger = logging.getLogger(__name__) diff --git a/src/anemoi/graphs/utils.py b/src/anemoi/graphs/utils.py index 9eed60a..1a25134 100644 --- a/src/anemoi/graphs/utils.py +++ b/src/anemoi/graphs/utils.py @@ -20,7 +20,10 @@ def get_nearest_neighbour(coords_rad: torch.Tensor, mask: Optional[torch.Tensor] NearestNeighbors fitted NearestNeighbour object """ - assert mask is None or mask.shape == (coords_rad.shape[0], 1), "Mask must have the same shape as the number of nodes." + assert mask is None or mask.shape == ( + coords_rad.shape[0], + 1, + ), "Mask must have the same shape as the number of nodes." nearest_neighbour = NearestNeighbors(metric="haversine", n_jobs=4) From b8b558dfd95376497a07b6f9ea5844f6fbdc10a2 Mon Sep 17 00:00:00 2001 From: Mario Santa Cruz Date: Tue, 25 Jun 2024 12:03:20 +0000 Subject: [PATCH 004/156] Ignore .pt files --- .gitignore | 1 + 1 file changed, 1 insertion(+) diff --git a/.gitignore b/.gitignore index 05d7042..d2278fa 100644 --- a/.gitignore +++ b/.gitignore @@ -188,3 +188,4 @@ _version.py *.code-workspace /config* +*.pt \ No newline at end of file From 7f6f4bdf16002577bedd2a426aaf1c7652da7ef7 Mon Sep 17 00:00:00 2001 From: Mario Santa Cruz Date: Tue, 25 Jun 2024 12:32:00 +0000 Subject: [PATCH 005/156] run pre-commit --- .gitignore | 2 +- src/anemoi/graphs/create.py | 4 ++-- src/anemoi/graphs/edges/attributes.py | 2 +- 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/.gitignore b/.gitignore index d2278fa..1b49006 100644 --- a/.gitignore +++ b/.gitignore @@ -188,4 +188,4 @@ _version.py *.code-workspace /config* -*.pt \ No newline at end of file +*.pt diff --git a/src/anemoi/graphs/create.py b/src/anemoi/graphs/create.py index cbbe879..d2b803d 100644 --- a/src/anemoi/graphs/create.py +++ b/src/anemoi/graphs/create.py @@ -1,11 +1,11 @@ +import logging import os + import torch from anemoi.utils.config import DotDict from hydra.utils import instantiate from torch_geometric.data import HeteroData -import logging - logger = logging.getLogger(__name__) diff --git a/src/anemoi/graphs/edges/attributes.py b/src/anemoi/graphs/edges/attributes.py index 15f6caf..0aa297d 100644 --- a/src/anemoi/graphs/edges/attributes.py +++ b/src/anemoi/graphs/edges/attributes.py @@ -1,8 +1,8 @@ +import logging from abc import ABC from abc import abstractmethod from typing import Optional -import logging import numpy as np from torch_geometric.data import HeteroData From d5f67fd6a01da6aed3afe6b275175e1beeca0b6d Mon Sep 17 00:00:00 2001 From: Mario Santa Cruz Date: Wed, 26 Jun 2024 11:12:52 +0000 Subject: [PATCH 006/156] docstring + log erros --- src/anemoi/graphs/edges/__init__.py | 4 ++++ src/anemoi/graphs/nodes/__init__.py | 4 ++++ src/anemoi/graphs/nodes/nodes.py | 23 +---------------------- src/anemoi/graphs/nodes/weights.py | 7 ++++--- src/anemoi/graphs/normalizer.py | 10 ++++++++-- 5 files changed, 21 insertions(+), 27 deletions(-) diff --git a/src/anemoi/graphs/edges/__init__.py b/src/anemoi/graphs/edges/__init__.py index e69de29..29875d0 100644 --- a/src/anemoi/graphs/edges/__init__.py +++ b/src/anemoi/graphs/edges/__init__.py @@ -0,0 +1,4 @@ +from .connections import CutOffEdgeBuilder +from .connections import KNNEdgeBuilder + +__all__ = ["KNNEdgeBuilder", "CutOffEdgeBuilder"] diff --git a/src/anemoi/graphs/nodes/__init__.py b/src/anemoi/graphs/nodes/__init__.py index e69de29..5458495 100644 --- a/src/anemoi/graphs/nodes/__init__.py +++ b/src/anemoi/graphs/nodes/__init__.py @@ -0,0 +1,4 @@ +from .nodes import NPZNodes +from .nodes import ZarrNodes + +__all__ = ["ZarrNodes", "NPZNodes"] diff --git a/src/anemoi/graphs/nodes/nodes.py b/src/anemoi/graphs/nodes/nodes.py index 8de951f..886125e 100644 --- a/src/anemoi/graphs/nodes/nodes.py +++ b/src/anemoi/graphs/nodes/nodes.py @@ -11,30 +11,10 @@ from torch_geometric.data import HeteroData logger = logging.getLogger(__name__) -earth_radius = 6371.0 # km - - -def latlon_to_radians(coords: np.ndarray) -> np.ndarray: - return np.deg2rad(coords) - - -def rad_to_latlon(coords: np.ndarray) -> np.ndarray: - """Converts coordinates from radians to degrees. - - Parameters - ---------- - coords : np.ndarray - Coordinates in radians. - - Returns - ------- - np.ndarray - _description_ - """ - return np.rad2deg(coords) class BaseNodeBuilder(ABC): + """Base class for node builders.""" def register_nodes(self, graph: HeteroData, name: str) -> None: graph[name].x = self.get_coordinates() @@ -52,7 +32,6 @@ def get_coordinates(self) -> np.ndarray: ... def reshape_coords(self, latitudes: np.ndarray, longitudes: np.ndarray) -> np.ndarray: coords = np.stack([latitudes, longitudes], axis=-1).reshape((-1, 2)) coords = np.deg2rad(coords) - # TODO: type needs to be variable? return torch.tensor(coords, dtype=torch.float32) def transform(self, graph: HeteroData, name: str, attr_config: DotDict) -> HeteroData: diff --git a/src/anemoi/graphs/nodes/weights.py b/src/anemoi/graphs/nodes/weights.py index 3afe523..25419cc 100644 --- a/src/anemoi/graphs/nodes/weights.py +++ b/src/anemoi/graphs/nodes/weights.py @@ -23,18 +23,19 @@ def __init__(self, norm: Optional[str] = None): @abstractmethod def compute(self, nodes: NodeStorage, *args, **kwargs): ... - def get_weights(self, *args, **kwargs): + def get_weights(self, *args, **kwargs) -> torch.Tensor: weights = self.compute(*args, **kwargs) if weights.ndim == 1: weights = weights[:, np.newaxis] - return self.normalize(weights) + norm_weights = self.normalize(weights) + return torch.tensor(norm_weights, dtype=torch.float32) class UniformWeights(BaseWeights): """Implements a uniform weight for the nodes.""" def compute(self, nodes: NodeStorage) -> np.ndarray: - return torch.ones(nodes.num_nodes) + return np.ones(nodes.num_nodes) class AreaWeights(BaseWeights): diff --git a/src/anemoi/graphs/normalizer.py b/src/anemoi/graphs/normalizer.py index 593de06..98820c0 100644 --- a/src/anemoi/graphs/normalizer.py +++ b/src/anemoi/graphs/normalizer.py @@ -19,5 +19,11 @@ def normalize(self, values: np.ndarray) -> np.ndarray: if self.norm == "unit-sum": return values / np.sum(values) if self.norm == "unit-std": - return values / np.std(values) - raise ValueError("Weight normalization must be 'l1', 'l2', 'unit-max' 'unit-sum' or 'unit-std'.") + std = np.std(values) + if std == 0: + logger.warning(f"Std. dev. of the {self.__class__.__name__} is 0. Cannot normalize.") + return values + return values / std + raise ValueError( + f"Weight normalization \"{values}\" is not valid. Options are: 'l1', 'l2', 'unit-max' 'unit-sum' or 'unit-std'." + ) From b12272d7e14a1e4282737877c24e7e0f6248ec0d Mon Sep 17 00:00:00 2001 From: Mario Santa Cruz Date: Wed, 26 Jun 2024 13:18:59 +0000 Subject: [PATCH 007/156] initial tests --- pyproject.toml | 2 ++ tests/conftest.py | 52 ++++++++++++++++++++++++++++++ tests/edges/test_attributes.py | 20 ++++++++++++ tests/edges/test_cutoff.py | 15 +++++++++ tests/edges/test_knn.py | 15 +++++++++ tests/nodes/test_npz.py | 58 ++++++++++++++++++++++++++++++++++ tests/nodes/test_weights.py | 53 +++++++++++++++++++++++++++++++ tests/nodes/test_zarr.py | 50 +++++++++++++++++++++++++++++ 8 files changed, 265 insertions(+) create mode 100644 tests/conftest.py create mode 100644 tests/edges/test_attributes.py create mode 100644 tests/edges/test_cutoff.py create mode 100644 tests/edges/test_knn.py create mode 100644 tests/nodes/test_npz.py create mode 100644 tests/nodes/test_weights.py create mode 100644 tests/nodes/test_zarr.py diff --git a/pyproject.toml b/pyproject.toml index 74c3e0c..1032632 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -62,6 +62,7 @@ optional-dependencies.dev = [ "nbsphinx", "pandoc", "pytest", + "pytest-mock", "requests", "sphinx", "sphinx-argparse", @@ -83,6 +84,7 @@ optional-dependencies.docs = [ optional-dependencies.tests = [ "pytest", + "pytest-mock", ] urls.Documentation = "https://anemoi-graphs.readthedocs.io/" diff --git a/tests/conftest.py b/tests/conftest.py new file mode 100644 index 0000000..80ebfaa --- /dev/null +++ b/tests/conftest.py @@ -0,0 +1,52 @@ +import numpy as np +import pytest +import torch +from torch_geometric.data import HeteroData + +lats = [-0.15, 0, 0.15] +lons = [0, 0.25, 0.5, 0.75] + + +class MockZarrDataset: + """Mock Zarr dataset with latitudes and longitudes attributes.""" + + def __init__(self, latitudes, longitudes): + self.latitudes = latitudes + self.longitudes = longitudes + self.num_nodes = len(latitudes) + + +@pytest.fixture +def mock_zarr_dataset() -> MockZarrDataset: + """Mock zarr dataset with nodes.""" + coords = 2 * torch.pi * np.array([[lat, lon] for lat in lats for lon in lons]) + return MockZarrDataset(latitudes=coords[:, 0], longitudes=coords[:, 1]) + + +@pytest.fixture +def mock_grids_path(tmp_path) -> tuple[str, int]: + """Mock grid_definition_path with files for 3 resolutions.""" + num_nodes = len(lats) * len(lons) + for resolution in ["o16", "o48", "5km5"]: + file_path = tmp_path / f"grid-{resolution}.npz" + np.savez(file_path, latitudes=np.random.rand(num_nodes), longitudes=np.random.rand(num_nodes)) + return str(tmp_path), num_nodes + + +@pytest.fixture +def graph_with_nodes() -> HeteroData: + """Graph with 12 nodes over the globe, stored in \"test_nodes\".""" + coords = np.array([[lat, lon] for lat in lats for lon in lons]) + graph = HeteroData() + graph["test_nodes"].x = 2 * torch.pi * torch.tensor(coords) + return graph + + +@pytest.fixture +def graph_nodes_and_edges() -> HeteroData: + """Graph with 1 set of nodes and edges.""" + coords = np.array([[lat, lon] for lat in lats for lon in lons]) + graph = HeteroData() + graph["test_nodes"].x = 2 * torch.pi * torch.tensor(coords) + graph[("test_nodes", "to", "test_nodes")].edge_index = torch.tensor([[0, 1], [1, 2], [2, 3], [3, 0]]) + return graph diff --git a/tests/edges/test_attributes.py b/tests/edges/test_attributes.py new file mode 100644 index 0000000..dcd756d --- /dev/null +++ b/tests/edges/test_attributes.py @@ -0,0 +1,20 @@ +import pytest +import torch + +from anemoi.graphs.edges.attributes import DirectionalFeatures + + +@pytest.mark.parametrize("norm", ["l1", "l2", "unit-max", "unit-sum", "unit-std"]) +@pytest.mark.parametrize("luse_rotated_features", [True, False]) +def test_directional_features(graph_nodes_and_edges, norm, luse_rotated_features: bool): + """Test DirectionalFeatures compute method.""" + edge_attr_builder = DirectionalFeatures(norm=norm, luse_rotated_features=luse_rotated_features) + edge_attr = edge_attr_builder(graph_nodes_and_edges, "test_nodes", "test_nodes") + assert isinstance(edge_attr, torch.Tensor) + + +def test_fail_directional_features(graph_nodes_and_edges): + """Test DirectionalFeatures compute method.""" + edge_attr_builder = DirectionalFeatures() + with pytest.raises(AttributeError): + edge_attr_builder(graph_nodes_and_edges, "test_nodes", "unknown_nodes") diff --git a/tests/edges/test_cutoff.py b/tests/edges/test_cutoff.py new file mode 100644 index 0000000..431d52c --- /dev/null +++ b/tests/edges/test_cutoff.py @@ -0,0 +1,15 @@ +import pytest + +from anemoi.graphs.edges import CutOffEdgeBuilder + + +def test_init(): + """Test CutOffEdgeBuilder initialization.""" + CutOffEdgeBuilder("test_nodes1", "test_nodes2", 0.5) + + +@pytest.mark.parametrize("cutoff_factor", [-0.5, "hello", None]) +def test_fail_init(cutoff_factor: str): + """Test CutOffEdgeBuilder initialization with invalid cutoff.""" + with pytest.raises(AssertionError): + CutOffEdgeBuilder("test_nodes1", "test_nodes2", cutoff_factor) diff --git a/tests/edges/test_knn.py b/tests/edges/test_knn.py new file mode 100644 index 0000000..282cbf7 --- /dev/null +++ b/tests/edges/test_knn.py @@ -0,0 +1,15 @@ +import pytest + +from anemoi.graphs.edges import KNNEdgeBuilder + + +def test_init(): + """Test CutOffEdgeBuilder initialization.""" + KNNEdgeBuilder("test_nodes1", "test_nodes2", 3) + + +@pytest.mark.parametrize("num_nearest_neighbours", [-1, 2.6, "hello", None]) +def test_fail_init(num_nearest_neighbours: str): + """Test KNNEdgeBuilder initialization with invalid number of nearest neighbours.""" + with pytest.raises(AssertionError): + KNNEdgeBuilder("test_nodes1", "test_nodes2", num_nearest_neighbours) diff --git a/tests/nodes/test_npz.py b/tests/nodes/test_npz.py new file mode 100644 index 0000000..8642e39 --- /dev/null +++ b/tests/nodes/test_npz.py @@ -0,0 +1,58 @@ +import pytest +import torch +from torch_geometric.data import HeteroData + +from anemoi.graphs.nodes.nodes import NPZNodes +from anemoi.graphs.nodes.weights import AreaWeights +from anemoi.graphs.nodes.weights import UniformWeights + + +@pytest.mark.parametrize("resolution", ["o16", "o48", "5km5"]) +def test_init(mock_grids_path: tuple[str, int], resolution: str): + """Test NPZNodes initialization.""" + grid_definition_path, _ = mock_grids_path + node_builder = NPZNodes(resolution, grid_definition_path=grid_definition_path) + assert isinstance(node_builder, NPZNodes) + + +@pytest.mark.parametrize("resolution", ["o17", 13, "ajsnb", None]) +def test_fail_init_wrong_resolution(mock_grids_path: tuple[str, int], resolution: str): + """Test NPZNodes initialization with invalid resolution.""" + grid_definition_path, _ = mock_grids_path + with pytest.raises(FileNotFoundError): + NPZNodes(resolution, grid_definition_path=grid_definition_path) + + +def test_fail_init_wrong_path(): + """Test NPZNodes initialization with invalid path.""" + with pytest.raises(FileNotFoundError): + NPZNodes("o16", "invalid_path") + + +@pytest.mark.parametrize("resolution", ["o16", "o48", "5km5"]) +def test_register_nodes(mock_grids_path: str, resolution: str): + """Test NPZNodes register correctly the nodes.""" + graph = HeteroData() + grid_definition_path, num_nodes = mock_grids_path + node_builder = NPZNodes(resolution, grid_definition_path=grid_definition_path) + + graph = node_builder.register_nodes(graph, "test_nodes") + + assert graph["test_nodes"].x is not None + assert isinstance(graph["test_nodes"].x, torch.Tensor) + assert graph["test_nodes"].x.shape == (num_nodes, 2) + assert graph["test_nodes"].node_type == "NPZNodes" + + +@pytest.mark.parametrize("attr_class", [UniformWeights, AreaWeights]) +def test_register_weights(graph_with_nodes: HeteroData, mock_grids_path: tuple[str, int], attr_class): + """Test NPZNodes register correctly the weights.""" + grid_definition_path, _ = mock_grids_path + node_builder = NPZNodes("o16", grid_definition_path=grid_definition_path) + config = {"test_attr": {"_target_": f"anemoi.graphs.nodes.weights.{attr_class.__name__}"}} + + graph = node_builder.register_attributes(graph_with_nodes, "test_nodes", config) + + assert graph["test_nodes"]["test_attr"] is not None + assert isinstance(graph["test_nodes"]["test_attr"], torch.Tensor) + assert graph["test_nodes"]["test_attr"].shape[0] == graph["test_nodes"].x.shape[0] diff --git a/tests/nodes/test_weights.py b/tests/nodes/test_weights.py new file mode 100644 index 0000000..db80dce --- /dev/null +++ b/tests/nodes/test_weights.py @@ -0,0 +1,53 @@ +import numpy as np +import pytest +import torch +from hydra.utils import instantiate +from torch_geometric.data import HeteroData + + +@pytest.mark.parametrize("norm", [None, "l1", "l2", "unit-max", "unit-sum", "unit-std"]) +def test_uniform_weights(graph_with_nodes: HeteroData, norm: str): + """Test NPZNodes register correctly the weights.""" + config = {"_target_": "anemoi.graphs.nodes.weights.UniformWeights", "norm": norm} + + weights = instantiate(config).get_weights(graph_with_nodes["test_nodes"]) + + assert weights is not None + assert isinstance(weights, torch.Tensor) + assert weights.shape[0] == graph_with_nodes["test_nodes"].x.shape[0] + + +@pytest.mark.parametrize("norm", ["l3", "invalide"]) +def test_uniform_weights_fail(graph_with_nodes: HeteroData, norm: str): + """Test NPZNodes register correctly the weights.""" + config = {"_target_": "anemoi.graphs.nodes.weights.UniformWeights", "norm": norm} + + with pytest.raises(ValueError): + instantiate(config).get_weights(graph_with_nodes["test_nodes"]) + + +def test_area_weights(graph_with_nodes: HeteroData): + """Test NPZNodes register correctly the weights.""" + config = { + "_target_": "anemoi.graphs.nodes.weights.AreaWeights", + "radius": 1.0, + "centre": np.array([0, 0, 0]), + } + + weights = instantiate(config).get_weights(graph_with_nodes["test_nodes"]) + + assert weights is not None + assert isinstance(weights, torch.Tensor) + assert weights.shape[0] == graph_with_nodes["test_nodes"].x.shape[0] + + +@pytest.mark.parametrize("radius", [-1.0, "hello", None]) +def test_area_weights_fail(graph_with_nodes: HeteroData, radius: float): + config = { + "_target_": "anemoi.graphs.nodes.weights.AreaWeights", + "radius": radius, + "centre": np.array([0, 0, 0]), + } + + with pytest.raises(ValueError): + instantiate(config).get_weights(graph_with_nodes["test_nodes"]) diff --git a/tests/nodes/test_zarr.py b/tests/nodes/test_zarr.py new file mode 100644 index 0000000..e9a5234 --- /dev/null +++ b/tests/nodes/test_zarr.py @@ -0,0 +1,50 @@ +import pytest +import torch +import zarr +from torch_geometric.data import HeteroData + +from anemoi.graphs.nodes import nodes +from anemoi.graphs.nodes.weights import AreaWeights +from anemoi.graphs.nodes.weights import UniformWeights + + +def test_init(mocker, mock_zarr_dataset): + """Test ZarrNodes initialization.""" + mocker.patch.object(nodes, "open_dataset", return_value=mock_zarr_dataset) + node_builder = nodes.ZarrNodes("dataset.zarr") + assert isinstance(node_builder, nodes.BaseNodeBuilder) + assert isinstance(node_builder, nodes.ZarrNodes) + + +def test_fail_init(): + """Test ZarrNodes initialization with invalid resolution.""" + with pytest.raises(zarr.errors.PathNotFoundError): + nodes.ZarrNodes("invalid_path.zarr") + + +def test_register_nodes(mocker, mock_zarr_dataset): + """Test ZarrNodes register correctly the nodes.""" + mocker.patch.object(nodes, "open_dataset", return_value=mock_zarr_dataset) + node_builder = nodes.ZarrNodes("dataset.zarr") + graph = HeteroData() + + graph = node_builder.register_nodes(graph, "test_nodes") + + assert graph["test_nodes"].x is not None + assert isinstance(graph["test_nodes"].x, torch.Tensor) + assert graph["test_nodes"].x.shape == (node_builder.ds.num_nodes, 2) + assert graph["test_nodes"].node_type == "ZarrNodes" + + +@pytest.mark.parametrize("attr_class", [UniformWeights, AreaWeights]) +def test_register_weights(mocker, graph_with_nodes: HeteroData, attr_class): + """Test ZarrNodes register correctly the weights.""" + mocker.patch.object(nodes, "open_dataset", return_value=None) + node_builder = nodes.ZarrNodes("dataset.zarr") + config = {"test_attr": {"_target_": f"anemoi.graphs.nodes.weights.{attr_class.__name__}"}} + + graph = node_builder.register_attributes(graph_with_nodes, "test_nodes", config) + + assert graph["test_nodes"]["test_attr"] is not None + assert isinstance(graph["test_nodes"]["test_attr"], torch.Tensor) + assert graph["test_nodes"]["test_attr"].shape[0] == graph["test_nodes"].x.shape[0] From cce5ea6f51e6c98245dace0d8f692d24a21c16bf Mon Sep 17 00:00:00 2001 From: theissenhelen Date: Wed, 26 Jun 2024 14:22:50 +0000 Subject: [PATCH 008/156] feat: initial version of AttributeBuilder --- src/anemoi/graphs/edges/attributes.py | 33 +++++++++++++++++++++++++++ 1 file changed, 33 insertions(+) diff --git a/src/anemoi/graphs/edges/attributes.py b/src/anemoi/graphs/edges/attributes.py index 0aa297d..8c86bca 100644 --- a/src/anemoi/graphs/edges/attributes.py +++ b/src/anemoi/graphs/edges/attributes.py @@ -3,14 +3,47 @@ from abc import abstractmethod from typing import Optional +import torch +from anemoi.utils.config import DotDict import numpy as np from torch_geometric.data import HeteroData +from hydra.utils import instantiate from anemoi.graphs.edges.directional import directional_edge_features from anemoi.graphs.normalizer import NormalizerMixin logger = logging.getLogger(__name__) +class AttributeBuilder(): + + def transform(self, graph: HeteroData, graph_config: DotDict): + + for name, nodes_cfg in graph_config.nodes.items(): + graph = self.register_node_attributes(graph, name, nodes_cfg.get("attributes", {})) + for edges_cfg in graph_config.edges: + graph = self.register_edge_attributes(graph, edges_cfg.nodes.src_name, edges_cfg.nodes.dst_name, edges_cfg.get("attributes", {})) + return graph + + def register_node_attributes(self, graph: HeteroData, node_name: str, node_config: DotDict): + assert node_name in graph.keys(), f"Node {node_name} does not exist in the graph." + for attr_name, attr_cfg in node_config.items(): + graph[node_name][attr_name] = instantiate(attr_cfg).compute(graph, node_name) + return graph + + def register_edge_attributes(self, graph: HeteroData, src_name: str, dst_name: str, edge_config: DotDict): + + for attr_name, attr_cfg in edge_config.items(): + attr_values = instantiate(attr_cfg).compute(graph, src_name, dst_name) + graph = self.register_edge_attribute(graph, src_name, dst_name, attr_name, attr_values) + return graph + + def register_edge_attribute(self, graph: HeteroData, src_name: str, dst_name: str, attr_name: str, attr_values: torch.Tensor): + num_edges = graph[(src_name, "to", dst_name)].num_edges + assert ( attr_values.shape[0] == num_edges), f"Number of edge features ({attr_values.shape[0]}) must match number of edges ({num_edges})." + + graph[(src_name, "to", dst_name)][attr_name] = attr_values + return graph + class BaseEdgeAttribute(ABC, NormalizerMixin): norm: Optional[str] = None From 9ba039149d2e9f9a962fdb51e51674ad4b8e6adc Mon Sep 17 00:00:00 2001 From: theissenhelen Date: Wed, 26 Jun 2024 14:40:06 +0000 Subject: [PATCH 009/156] refactor: separate into node edge attribute builders --- src/anemoi/graphs/edges/attributes.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/src/anemoi/graphs/edges/attributes.py b/src/anemoi/graphs/edges/attributes.py index 8c86bca..6b57e70 100644 --- a/src/anemoi/graphs/edges/attributes.py +++ b/src/anemoi/graphs/edges/attributes.py @@ -14,15 +14,12 @@ logger = logging.getLogger(__name__) -class AttributeBuilder(): +class NodeAttributeBuilder(): def transform(self, graph: HeteroData, graph_config: DotDict): for name, nodes_cfg in graph_config.nodes.items(): graph = self.register_node_attributes(graph, name, nodes_cfg.get("attributes", {})) - for edges_cfg in graph_config.edges: - graph = self.register_edge_attributes(graph, edges_cfg.nodes.src_name, edges_cfg.nodes.dst_name, edges_cfg.get("attributes", {})) - return graph def register_node_attributes(self, graph: HeteroData, node_name: str, node_config: DotDict): assert node_name in graph.keys(), f"Node {node_name} does not exist in the graph." @@ -30,6 +27,13 @@ def register_node_attributes(self, graph: HeteroData, node_name: str, node_confi graph[node_name][attr_name] = instantiate(attr_cfg).compute(graph, node_name) return graph +class EdgeAttributeBuilder(): + + def transform(self, graph: HeteroData, graph_config: DotDict): + for edges_cfg in graph_config.edges: + graph = self.register_edge_attributes(graph, edges_cfg.nodes.src_name, edges_cfg.nodes.dst_name, edges_cfg.get("attributes", {})) + return graph + def register_edge_attributes(self, graph: HeteroData, src_name: str, dst_name: str, edge_config: DotDict): for attr_name, attr_cfg in edge_config.items(): From 9a47184ac06a5f763852cda4c852b907c77213de Mon Sep 17 00:00:00 2001 From: Mario Santa Cruz Date: Thu, 27 Jun 2024 08:05:05 +0000 Subject: [PATCH 010/156] feat: edge_length moved to edges/attributes.py --- src/anemoi/graphs/__init__.py | 2 +- src/anemoi/graphs/create.py | 21 +++++- src/anemoi/graphs/edges/attributes.py | 93 ++++++++++++++------------ src/anemoi/graphs/edges/connections.py | 15 +---- src/anemoi/graphs/normalizer.py | 2 + src/anemoi/graphs/utils.py | 22 ++++++ tests/nodes/test_weights.py | 33 +++------ tests/test_normalizer.py | 52 ++++++++++++++ 8 files changed, 159 insertions(+), 81 deletions(-) create mode 100644 tests/test_normalizer.py diff --git a/src/anemoi/graphs/__init__.py b/src/anemoi/graphs/__init__.py index 80d19fb..715b8a4 100644 --- a/src/anemoi/graphs/__init__.py +++ b/src/anemoi/graphs/__init__.py @@ -7,4 +7,4 @@ from ._version import __version__ -earth_radius = 6371.0 # km +EARTH_RADIUS = 6371.0 # km diff --git a/src/anemoi/graphs/create.py b/src/anemoi/graphs/create.py index d2b803d..da1692e 100644 --- a/src/anemoi/graphs/create.py +++ b/src/anemoi/graphs/create.py @@ -10,13 +10,25 @@ def generate_graph(graph_config: DotDict) -> HeteroData: + """Generate a graph from a configuration. + + Parameters + ---------- + graph_config : DotDict + Configuration for the nodes and edges (and its attributes). + + Returns + ------- + HeteroData + Graph. + """ graph = HeteroData() for name, nodes_cfg in graph_config.nodes.items(): - graph = instantiate(nodes_cfg.node_type).transform(graph, name, nodes_cfg.get("attributes", {})) + graph = instantiate(nodes_cfg.node_builder).transform(graph, name, nodes_cfg.get("attributes", {})) for edges_cfg in graph_config.edges: - graph = instantiate(edges_cfg.edge_type, **edges_cfg.nodes).transform(graph, edges_cfg.get("attributes", {})) + graph = instantiate(edges_cfg.edge_builder, **edges_cfg.nodes).transform(graph, edges_cfg.get("attributes", {})) return graph @@ -66,3 +78,8 @@ def _path_readable(self) -> bool: return True except FileNotFoundError: return False + + +if __name__ == "__main__": + creator = GraphCreator(config="/home/ecm1924/GitRepos/anemoi-graphs/recipe.yaml", path="graph.pt") + creator.create() diff --git a/src/anemoi/graphs/edges/attributes.py b/src/anemoi/graphs/edges/attributes.py index 6b57e70..9e7509f 100644 --- a/src/anemoi/graphs/edges/attributes.py +++ b/src/anemoi/graphs/edges/attributes.py @@ -1,74 +1,81 @@ import logging from abc import ABC from abc import abstractmethod +from dataclasses import dataclass from typing import Optional -import torch -from anemoi.utils.config import DotDict import numpy as np +import torch +from scipy.sparse import coo_matrix +from sklearn.preprocessing import normalize from torch_geometric.data import HeteroData -from hydra.utils import instantiate from anemoi.graphs.edges.directional import directional_edge_features from anemoi.graphs.normalizer import NormalizerMixin +from anemoi.graphs.utils import haversine_distance logger = logging.getLogger(__name__) -class NodeAttributeBuilder(): - - def transform(self, graph: HeteroData, graph_config: DotDict): - - for name, nodes_cfg in graph_config.nodes.items(): - graph = self.register_node_attributes(graph, name, nodes_cfg.get("attributes", {})) - - def register_node_attributes(self, graph: HeteroData, node_name: str, node_config: DotDict): - assert node_name in graph.keys(), f"Node {node_name} does not exist in the graph." - for attr_name, attr_cfg in node_config.items(): - graph[node_name][attr_name] = instantiate(attr_cfg).compute(graph, node_name) - return graph - -class EdgeAttributeBuilder(): - - def transform(self, graph: HeteroData, graph_config: DotDict): - for edges_cfg in graph_config.edges: - graph = self.register_edge_attributes(graph, edges_cfg.nodes.src_name, edges_cfg.nodes.dst_name, edges_cfg.get("attributes", {})) - return graph - - def register_edge_attributes(self, graph: HeteroData, src_name: str, dst_name: str, edge_config: DotDict): - - for attr_name, attr_cfg in edge_config.items(): - attr_values = instantiate(attr_cfg).compute(graph, src_name, dst_name) - graph = self.register_edge_attribute(graph, src_name, dst_name, attr_name, attr_values) - return graph - - def register_edge_attribute(self, graph: HeteroData, src_name: str, dst_name: str, attr_name: str, attr_values: torch.Tensor): - num_edges = graph[(src_name, "to", dst_name)].num_edges - assert ( attr_values.shape[0] == num_edges), f"Number of edge features ({attr_values.shape[0]}) must match number of edges ({num_edges})." - - graph[(src_name, "to", dst_name)][attr_name] = attr_values - return graph - +@dataclass class BaseEdgeAttribute(ABC, NormalizerMixin): norm: Optional[str] = None @abstractmethod - def compute(self, graph: HeteroData, *args, **kwargs): ... + def compute(self, graph: HeteroData, *args, **kwargs) -> np.ndarray: ... + + def post_process(self, values: np.ndarray) -> torch.Tensor: + return torch.tensor(values) - def __call__(self, *args, **kwargs): + def __call__(self, *args, **kwargs) -> torch.Tensor: values = self.compute(*args, **kwargs) - if values.ndim == 1: - values = values[:, np.newaxis] - return self.normalize(values) + normed_values = self.normalize(values) + if normed_values.ndim == 1: + normed_values = normed_values[:, np.newaxis] + return self.post_process(normed_values) +@dataclass class DirectionalFeatures(BaseEdgeAttribute): + """Compute directional features for edges.""" + norm: Optional[str] = None luse_rotated_features: bool = False - def compute(self, graph: HeteroData, src_name: str, dst_name: str): + def compute(self, graph: HeteroData, src_name: str, dst_name: str) -> torch.Tensor: edge_index = graph[(src_name, "to", dst_name)].edge_index src_coords = graph[src_name].x.numpy()[edge_index[0]].T dst_coords = graph[dst_name].x.numpy()[edge_index[1]].T edge_dirs = directional_edge_features(src_coords, dst_coords, self.luse_rotated_features).T return edge_dirs + + +@dataclass +class HaversineDistance(BaseEdgeAttribute): + """Edge length feature.""" + + norm: str = "l1" + invert: bool = True + + def compute(self, graph: HeteroData, src_name: str, dst_name: str): + """Compute haversine distance (in kilometers) between nodes connected by edges.""" + assert src_name in graph.node_types, f"Node {src_name} not found in graph." + assert dst_name in graph.node_types, f"Node {dst_name} not found in graph." + edge_index = graph[(src_name, "to", dst_name)].edge_index + src_coords = graph[src_name].x.numpy()[edge_index[0]] + dst_coords = graph[dst_name].x.numpy()[edge_index[1]] + edge_lengths = haversine_distance(src_coords, dst_coords) + return coo_matrix((edge_lengths, (edge_index[1], edge_index[0]))) + + def normalize(self, values) -> np.ndarray: + """Normalize the edge length. + + This method scales the edge lengths to a unit norm, computing the norms + for each source node (axis=1). + """ + return normalize(values, norm="l1", axis=1).data + + def post_process(self, values: np.ndarray) -> torch.Tensor: + if self.invert: + values = 1 - values + return super().post_process(values) diff --git a/src/anemoi/graphs/edges/connections.py b/src/anemoi/graphs/edges/connections.py index 6bf057e..49080ce 100644 --- a/src/anemoi/graphs/edges/connections.py +++ b/src/anemoi/graphs/edges/connections.py @@ -7,11 +7,10 @@ from anemoi.utils.config import DotDict from hydra.utils import instantiate from sklearn.neighbors import NearestNeighbors -from sklearn.preprocessing import normalize from torch_geometric.data import HeteroData from torch_geometric.data.storage import NodeStorage -from anemoi.graphs import earth_radius +from anemoi.graphs import EARTH_RADIUS from anemoi.graphs.utils import get_grid_reference_distance logger = logging.getLogger(__name__) @@ -54,13 +53,9 @@ def transform(self, graph: HeteroData, attrs_config: Optional[DotDict] = None) - # Compute adjacency matrix. adjmat = self.get_adj_matrix(src_nodes, dst_nodes) - # Normalize adjacency matrix. - adjmat_norm = self.normalize_adjmat(adjmat) - # Add edges to the graph and register normed distance. graph = self.register_edges(graph, adjmat.col, adjmat.row) - self.register_edge_attribute(graph, "normed_dist", adjmat_norm.data) if attrs_config is not None: for attr_name, attr_cfg in attrs_config.items(): attr_values = instantiate(attr_cfg)(graph, self.src_name, self.dst_name) @@ -68,12 +63,6 @@ def transform(self, graph: HeteroData, attrs_config: Optional[DotDict] = None) - return graph - def normalize_adjmat(self, adjmat): - """Normalize a sparse adjacency matrix.""" - adjmat_norm = normalize(adjmat, norm="l1", axis=1) - adjmat_norm.data = 1.0 - adjmat_norm.data - return adjmat_norm - class KNNEdgeBuilder(BaseEdgeBuilder): """Computes KNN based edges and adds them to the graph.""" @@ -124,7 +113,7 @@ def prepare_node_data(self, graph: HeteroData): return super().prepare_node_data(graph) def get_adj_matrix(self, src_nodes: NodeStorage, dst_nodes: NodeStorage): - logger.debug("Using cut-off radius of %.1f km.", self.radius * earth_radius) + logger.debug("Using cut-off radius of %.1f km.", self.radius * EARTH_RADIUS) nearest_neighbour = NearestNeighbors(metric="haversine", n_jobs=4) nearest_neighbour.fit(src_nodes.x) diff --git a/src/anemoi/graphs/normalizer.py b/src/anemoi/graphs/normalizer.py index 98820c0..5b3edcd 100644 --- a/src/anemoi/graphs/normalizer.py +++ b/src/anemoi/graphs/normalizer.py @@ -6,6 +6,8 @@ class NormalizerMixin: + """Mixin class for normalizing attributes.""" + def normalize(self, values: np.ndarray) -> np.ndarray: if self.norm is None: logger.debug("Node weights are not normalized.") diff --git a/src/anemoi/graphs/utils.py b/src/anemoi/graphs/utils.py index 1a25134..f655e8d 100644 --- a/src/anemoi/graphs/utils.py +++ b/src/anemoi/graphs/utils.py @@ -109,3 +109,25 @@ def get_index_in_outer_join(vector: torch.Tensor, tensor: torch.Tensor) -> int: if mask.any(): return int(torch.where(mask)[0]) return -1 + + +def haversine_distance(src_coords: np.ndarray, dst_coords: np.ndarray) -> np.ndarray: + """Haversine distance. + + Parameters + ---------- + src_coords : np.ndarray of shape (N, 2) + Source coordinates in radians. + dst_coords : np.ndarray of shape (N, 2) + Destination coordinates in radians. + + Returns + ------- + np.ndarray of shape (N,) + Haversine distance between source and destination coordinates. + """ + dlat = dst_coords[:, 0] - src_coords[:, 0] + dlon = dst_coords[:, 1] - src_coords[:, 1] + a = np.sin(dlat / 2) ** 2 + np.cos(src_coords[:, 0]) * np.cos(dst_coords[:, 0]) * np.sin(dlon / 2) ** 2 + c = 2 * np.arctan2(np.sqrt(a), np.sqrt(1 - a)) + return c diff --git a/tests/nodes/test_weights.py b/tests/nodes/test_weights.py index db80dce..71e54fa 100644 --- a/tests/nodes/test_weights.py +++ b/tests/nodes/test_weights.py @@ -1,16 +1,16 @@ -import numpy as np import pytest import torch -from hydra.utils import instantiate from torch_geometric.data import HeteroData +from anemoi.graphs.nodes.weights import AreaWeights +from anemoi.graphs.nodes.weights import UniformWeights + @pytest.mark.parametrize("norm", [None, "l1", "l2", "unit-max", "unit-sum", "unit-std"]) def test_uniform_weights(graph_with_nodes: HeteroData, norm: str): """Test NPZNodes register correctly the weights.""" - config = {"_target_": "anemoi.graphs.nodes.weights.UniformWeights", "norm": norm} - - weights = instantiate(config).get_weights(graph_with_nodes["test_nodes"]) + node_attr_builder = UniformWeights(norm=norm) + weights = node_attr_builder.get_weights(graph_with_nodes["test_nodes"]) assert weights is not None assert isinstance(weights, torch.Tensor) @@ -20,21 +20,15 @@ def test_uniform_weights(graph_with_nodes: HeteroData, norm: str): @pytest.mark.parametrize("norm", ["l3", "invalide"]) def test_uniform_weights_fail(graph_with_nodes: HeteroData, norm: str): """Test NPZNodes register correctly the weights.""" - config = {"_target_": "anemoi.graphs.nodes.weights.UniformWeights", "norm": norm} - with pytest.raises(ValueError): - instantiate(config).get_weights(graph_with_nodes["test_nodes"]) + node_attr_builder = UniformWeights(norm=norm) + node_attr_builder.get_weights(graph_with_nodes["test_nodes"]) def test_area_weights(graph_with_nodes: HeteroData): """Test NPZNodes register correctly the weights.""" - config = { - "_target_": "anemoi.graphs.nodes.weights.AreaWeights", - "radius": 1.0, - "centre": np.array([0, 0, 0]), - } - - weights = instantiate(config).get_weights(graph_with_nodes["test_nodes"]) + node_attr_builder = AreaWeights() + weights = node_attr_builder.get_weights(graph_with_nodes["test_nodes"]) assert weights is not None assert isinstance(weights, torch.Tensor) @@ -43,11 +37,6 @@ def test_area_weights(graph_with_nodes: HeteroData): @pytest.mark.parametrize("radius", [-1.0, "hello", None]) def test_area_weights_fail(graph_with_nodes: HeteroData, radius: float): - config = { - "_target_": "anemoi.graphs.nodes.weights.AreaWeights", - "radius": radius, - "centre": np.array([0, 0, 0]), - } - with pytest.raises(ValueError): - instantiate(config).get_weights(graph_with_nodes["test_nodes"]) + node_attr_builder = AreaWeights(radius=radius) + node_attr_builder.get_weights(graph_with_nodes["test_nodes"]) diff --git a/tests/test_normalizer.py b/tests/test_normalizer.py new file mode 100644 index 0000000..2654c0c --- /dev/null +++ b/tests/test_normalizer.py @@ -0,0 +1,52 @@ +import numpy as np +import pytest + +from anemoi.graphs.normalizer import NormalizerMixin + + +@pytest.mark.parametrize("norm", ["l1", "l2", "unit-max", "unit-sum", "unit-std"]) +def test_normalizer(norm: str): + """Test NormalizerMixin normalize method.""" + class Normalizer(NormalizerMixin): + def __init__(self, norm): + self.norm = norm + + def __call__(self, data): + return self.normalize(data) + + normalizer = Normalizer(norm=norm) + data = np.random.rand(10, 5) + normalized_data = normalizer(data) + assert isinstance(normalized_data, np.ndarray) + assert normalized_data.shape == data.shape + + +@pytest.mark.parametrize("norm", ["l3", "invalid"]) +def test_normalizer_wrong_norm(norm: str): + """Test NormalizerMixin normalize method.""" + class Normalizer(NormalizerMixin): + def __init__(self, norm: str): + self.norm = norm + + def __call__(self, data): + return self.normalize(data) + + with pytest.raises(ValueError): + normalizer = Normalizer(norm=norm) + data = np.random.rand(10, 5) + normalizer(data) + + +def test_normalizer_wrong_inheritance(): + """Test NormalizerMixin normalize method.""" + class Normalizer(NormalizerMixin): + def __init__(self, attr): + self.attr = attr + + def __call__(self, data): + return self.normalize(data) + + with pytest.raises(AttributeError): + normalizer = Normalizer(attr="attr_name") + data = np.random.rand(10, 5) + normalizer(data) From 384adc7f7f4a85ee23cde85c00ada2928514e129 Mon Sep 17 00:00:00 2001 From: Mario Santa Cruz Date: Thu, 27 Jun 2024 08:05:34 +0000 Subject: [PATCH 011/156] remove __init__ --- src/anemoi/graphs/create.py | 5 ----- 1 file changed, 5 deletions(-) diff --git a/src/anemoi/graphs/create.py b/src/anemoi/graphs/create.py index da1692e..7a1d2a9 100644 --- a/src/anemoi/graphs/create.py +++ b/src/anemoi/graphs/create.py @@ -78,8 +78,3 @@ def _path_readable(self) -> bool: return True except FileNotFoundError: return False - - -if __name__ == "__main__": - creator = GraphCreator(config="/home/ecm1924/GitRepos/anemoi-graphs/recipe.yaml", path="graph.pt") - creator.create() From 0bc176cc909ab43d4ec1c1b9cddb988667c6671f Mon Sep 17 00:00:00 2001 From: Mario Santa Cruz Date: Thu, 27 Jun 2024 08:10:59 +0000 Subject: [PATCH 012/156] feat: test edge builders --- tests/edges/test_cutoff.py | 7 +++++++ tests/edges/test_knn.py | 7 +++++++ 2 files changed, 14 insertions(+) diff --git a/tests/edges/test_cutoff.py b/tests/edges/test_cutoff.py index 431d52c..844cf87 100644 --- a/tests/edges/test_cutoff.py +++ b/tests/edges/test_cutoff.py @@ -13,3 +13,10 @@ def test_fail_init(cutoff_factor: str): """Test CutOffEdgeBuilder initialization with invalid cutoff.""" with pytest.raises(AssertionError): CutOffEdgeBuilder("test_nodes1", "test_nodes2", cutoff_factor) + + +def test_cutoff(graph_with_nodes): + """Test CutOffEdgeBuilder.""" + builder = CutOffEdgeBuilder("test_nodes", "test_nodes", 0.5) + graph = builder.transform(graph_with_nodes) + assert ("test_nodes", "to", "test_nodes") in graph.edge_types diff --git a/tests/edges/test_knn.py b/tests/edges/test_knn.py index 282cbf7..aee529a 100644 --- a/tests/edges/test_knn.py +++ b/tests/edges/test_knn.py @@ -13,3 +13,10 @@ def test_fail_init(num_nearest_neighbours: str): """Test KNNEdgeBuilder initialization with invalid number of nearest neighbours.""" with pytest.raises(AssertionError): KNNEdgeBuilder("test_nodes1", "test_nodes2", num_nearest_neighbours) + + +def test_knn(graph_with_nodes): + """Test KNNEdgeBuilder.""" + builder = KNNEdgeBuilder("test_nodes", "test_nodes", 3) + graph = builder.transform(graph_with_nodes) + assert ("test_nodes", "to", "test_nodes") in graph.edge_types From d16934ba69884ee453b227d9c9c58bbfed730d19 Mon Sep 17 00:00:00 2001 From: Mario Santa Cruz Date: Thu, 27 Jun 2024 08:13:51 +0000 Subject: [PATCH 013/156] add blank lines --- tests/test_normalizer.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/tests/test_normalizer.py b/tests/test_normalizer.py index 2654c0c..25941b8 100644 --- a/tests/test_normalizer.py +++ b/tests/test_normalizer.py @@ -7,10 +7,11 @@ @pytest.mark.parametrize("norm", ["l1", "l2", "unit-max", "unit-sum", "unit-std"]) def test_normalizer(norm: str): """Test NormalizerMixin normalize method.""" + class Normalizer(NormalizerMixin): def __init__(self, norm): self.norm = norm - + def __call__(self, data): return self.normalize(data) @@ -24,10 +25,11 @@ def __call__(self, data): @pytest.mark.parametrize("norm", ["l3", "invalid"]) def test_normalizer_wrong_norm(norm: str): """Test NormalizerMixin normalize method.""" + class Normalizer(NormalizerMixin): def __init__(self, norm: str): self.norm = norm - + def __call__(self, data): return self.normalize(data) @@ -39,10 +41,11 @@ def __call__(self, data): def test_normalizer_wrong_inheritance(): """Test NormalizerMixin normalize method.""" + class Normalizer(NormalizerMixin): def __init__(self, attr): self.attr = attr - + def __call__(self, data): return self.normalize(data) From 0f82ea7413be23f39607d2c0a0ff4ffd6f4ca656 Mon Sep 17 00:00:00 2001 From: Mario Santa Cruz Date: Thu, 27 Jun 2024 09:59:26 +0000 Subject: [PATCH 014/156] dep: hydra-core --- pyproject.toml | 1 + 1 file changed, 1 insertion(+) diff --git a/pyproject.toml b/pyproject.toml index 1032632..5695764 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -52,6 +52,7 @@ dynamic = [ dependencies = [ "anemoi-datasets[data]>=0.3.3", "anemoi-utils>=0.3.6", + "hydra-core==1.3", "torch>=2.2", "torch-geometric>=2.3.1,<2.5", ] From a9c5adabe20cd3f64ef42e758636c41a5fa56f7f Mon Sep 17 00:00:00 2001 From: Mario Santa Cruz Date: Thu, 27 Jun 2024 14:54:17 +0000 Subject: [PATCH 015/156] bugfix (encoder edge lengths) + refector --- src/anemoi/graphs/create.py | 41 +++++++------------ src/anemoi/graphs/edges/__init__.py | 4 +- src/anemoi/graphs/edges/attributes.py | 16 ++------ .../edges/{connections.py => builder.py} | 0 src/anemoi/graphs/nodes/__init__.py | 4 +- .../graphs/nodes/{nodes.py => builder.py} | 0 tests/conftest.py | 40 ++++++++++++++++++ tests/nodes/test_npz.py | 2 +- tests/nodes/test_zarr.py | 20 ++++----- tests/test_graphs.py | 22 ++++++++-- 10 files changed, 90 insertions(+), 59 deletions(-) rename src/anemoi/graphs/edges/{connections.py => builder.py} (100%) rename src/anemoi/graphs/nodes/{nodes.py => builder.py} (100%) diff --git a/src/anemoi/graphs/create.py b/src/anemoi/graphs/create.py index 7a1d2a9..53861a9 100644 --- a/src/anemoi/graphs/create.py +++ b/src/anemoi/graphs/create.py @@ -9,31 +9,9 @@ logger = logging.getLogger(__name__) -def generate_graph(graph_config: DotDict) -> HeteroData: - """Generate a graph from a configuration. - - Parameters - ---------- - graph_config : DotDict - Configuration for the nodes and edges (and its attributes). - - Returns - ------- - HeteroData - Graph. - """ - graph = HeteroData() - - for name, nodes_cfg in graph_config.nodes.items(): - graph = instantiate(nodes_cfg.node_builder).transform(graph, name, nodes_cfg.get("attributes", {})) - - for edges_cfg in graph_config.edges: - graph = instantiate(edges_cfg.edge_builder, **edges_cfg.nodes).transform(graph, edges_cfg.get("attributes", {})) - - return graph - - class GraphCreator: + """Graph creator.""" + def __init__( self, path, @@ -55,9 +33,18 @@ def init(self): if self._path_readable() and not self.overwrite: raise Exception(f"{self.path} already exists. Use overwrite=True to overwrite.") - def load(self) -> HeteroData: + def generate_graph(self) -> HeteroData: config = DotDict.from_file(self.config) - graph = generate_graph(config) + + graph = HeteroData() + for name, nodes_cfg in config.nodes.items(): + graph = instantiate(nodes_cfg.node_builder).transform(graph, name, nodes_cfg.get("attributes", {})) + + for edges_cfg in config.edges: + graph = instantiate(edges_cfg.edge_builder, **edges_cfg.nodes).transform( + graph, edges_cfg.get("attributes", {}) + ) + return graph def save(self, graph: HeteroData) -> None: @@ -67,7 +54,7 @@ def save(self, graph: HeteroData) -> None: def create(self): self.init() - graph = self.load() + graph = self.generate_graph() self.save(graph) def _path_readable(self) -> bool: diff --git a/src/anemoi/graphs/edges/__init__.py b/src/anemoi/graphs/edges/__init__.py index 29875d0..edd07db 100644 --- a/src/anemoi/graphs/edges/__init__.py +++ b/src/anemoi/graphs/edges/__init__.py @@ -1,4 +1,4 @@ -from .connections import CutOffEdgeBuilder -from .connections import KNNEdgeBuilder +from .builder import CutOffEdgeBuilder +from .builder import KNNEdgeBuilder __all__ = ["KNNEdgeBuilder", "CutOffEdgeBuilder"] diff --git a/src/anemoi/graphs/edges/attributes.py b/src/anemoi/graphs/edges/attributes.py index 9e7509f..47787b3 100644 --- a/src/anemoi/graphs/edges/attributes.py +++ b/src/anemoi/graphs/edges/attributes.py @@ -6,8 +6,6 @@ import numpy as np import torch -from scipy.sparse import coo_matrix -from sklearn.preprocessing import normalize from torch_geometric.data import HeteroData from anemoi.graphs.edges.directional import directional_edge_features @@ -51,13 +49,13 @@ def compute(self, graph: HeteroData, src_name: str, dst_name: str) -> torch.Tens @dataclass -class HaversineDistance(BaseEdgeAttribute): +class EdgeLength(BaseEdgeAttribute): """Edge length feature.""" norm: str = "l1" invert: bool = True - def compute(self, graph: HeteroData, src_name: str, dst_name: str): + def compute(self, graph: HeteroData, src_name: str, dst_name: str) -> np.ndarray: """Compute haversine distance (in kilometers) between nodes connected by edges.""" assert src_name in graph.node_types, f"Node {src_name} not found in graph." assert dst_name in graph.node_types, f"Node {dst_name} not found in graph." @@ -65,15 +63,7 @@ def compute(self, graph: HeteroData, src_name: str, dst_name: str): src_coords = graph[src_name].x.numpy()[edge_index[0]] dst_coords = graph[dst_name].x.numpy()[edge_index[1]] edge_lengths = haversine_distance(src_coords, dst_coords) - return coo_matrix((edge_lengths, (edge_index[1], edge_index[0]))) - - def normalize(self, values) -> np.ndarray: - """Normalize the edge length. - - This method scales the edge lengths to a unit norm, computing the norms - for each source node (axis=1). - """ - return normalize(values, norm="l1", axis=1).data + return edge_lengths def post_process(self, values: np.ndarray) -> torch.Tensor: if self.invert: diff --git a/src/anemoi/graphs/edges/connections.py b/src/anemoi/graphs/edges/builder.py similarity index 100% rename from src/anemoi/graphs/edges/connections.py rename to src/anemoi/graphs/edges/builder.py diff --git a/src/anemoi/graphs/nodes/__init__.py b/src/anemoi/graphs/nodes/__init__.py index 5458495..beecc98 100644 --- a/src/anemoi/graphs/nodes/__init__.py +++ b/src/anemoi/graphs/nodes/__init__.py @@ -1,4 +1,4 @@ -from .nodes import NPZNodes -from .nodes import ZarrNodes +from .builder import NPZNodes +from .builder import ZarrNodes __all__ = ["ZarrNodes", "NPZNodes"] diff --git a/src/anemoi/graphs/nodes/nodes.py b/src/anemoi/graphs/nodes/builder.py similarity index 100% rename from src/anemoi/graphs/nodes/nodes.py rename to src/anemoi/graphs/nodes/builder.py diff --git a/tests/conftest.py b/tests/conftest.py index 80ebfaa..b6b8ba0 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,6 +1,7 @@ import numpy as np import pytest import torch +import yaml from torch_geometric.data import HeteroData lats = [-0.15, 0, 0.15] @@ -50,3 +51,42 @@ def graph_nodes_and_edges() -> HeteroData: graph["test_nodes"].x = 2 * torch.pi * torch.tensor(coords) graph[("test_nodes", "to", "test_nodes")].edge_index = torch.tensor([[0, 1], [1, 2], [2, 3], [3, 0]]) return graph + + +@pytest.fixture +def config_file(tmp_path) -> tuple[str, str]: + """Mock grid_definition_path with files for 3 resolutions.""" + cfg = { + "nodes": { + "test_nodes": { + "node_builder": { + "_target_": "anemoi.graphs.nodes.NPZNodes", + "grid_definition_path": str(tmp_path), + "resolution": "o16", + }, + } + }, + "edges": [ + { + "nodes": {"src_name": "test_nodes", "dst_name": "test_nodes"}, + "edge_builder": { + "_target_": "anemoi.graphs.edges.KNNEdgeBuilder", + "num_nearest_neighbours": 3, + }, + "attributes": { + "dist_norm": { + "_target_": "anemoi.graphs.edges.attributes.EdgeLength", + "norm": "l1", + "invert": True, + }, + "directional_features": {"_target_": "anemoi.graphs.edges.attributes.DirectionalFeatures"}, + }, + }, + ], + } + file_name = "config.yaml" + + with (tmp_path / file_name).open("w") as file: + yaml.dump(cfg, file) + + return tmp_path, file_name diff --git a/tests/nodes/test_npz.py b/tests/nodes/test_npz.py index 8642e39..ebe88d9 100644 --- a/tests/nodes/test_npz.py +++ b/tests/nodes/test_npz.py @@ -2,7 +2,7 @@ import torch from torch_geometric.data import HeteroData -from anemoi.graphs.nodes.nodes import NPZNodes +from anemoi.graphs.nodes.builder import NPZNodes from anemoi.graphs.nodes.weights import AreaWeights from anemoi.graphs.nodes.weights import UniformWeights diff --git a/tests/nodes/test_zarr.py b/tests/nodes/test_zarr.py index e9a5234..ddf804f 100644 --- a/tests/nodes/test_zarr.py +++ b/tests/nodes/test_zarr.py @@ -3,29 +3,29 @@ import zarr from torch_geometric.data import HeteroData -from anemoi.graphs.nodes import nodes +from anemoi.graphs.nodes import builder from anemoi.graphs.nodes.weights import AreaWeights from anemoi.graphs.nodes.weights import UniformWeights def test_init(mocker, mock_zarr_dataset): """Test ZarrNodes initialization.""" - mocker.patch.object(nodes, "open_dataset", return_value=mock_zarr_dataset) - node_builder = nodes.ZarrNodes("dataset.zarr") - assert isinstance(node_builder, nodes.BaseNodeBuilder) - assert isinstance(node_builder, nodes.ZarrNodes) + mocker.patch.object(builder, "open_dataset", return_value=mock_zarr_dataset) + node_builder = builder.ZarrNodes("dataset.zarr") + assert isinstance(node_builder, builder.BaseNodeBuilder) + assert isinstance(node_builder, builder.ZarrNodes) def test_fail_init(): """Test ZarrNodes initialization with invalid resolution.""" with pytest.raises(zarr.errors.PathNotFoundError): - nodes.ZarrNodes("invalid_path.zarr") + builder.ZarrNodes("invalid_path.zarr") def test_register_nodes(mocker, mock_zarr_dataset): """Test ZarrNodes register correctly the nodes.""" - mocker.patch.object(nodes, "open_dataset", return_value=mock_zarr_dataset) - node_builder = nodes.ZarrNodes("dataset.zarr") + mocker.patch.object(builder, "open_dataset", return_value=mock_zarr_dataset) + node_builder = builder.ZarrNodes("dataset.zarr") graph = HeteroData() graph = node_builder.register_nodes(graph, "test_nodes") @@ -39,8 +39,8 @@ def test_register_nodes(mocker, mock_zarr_dataset): @pytest.mark.parametrize("attr_class", [UniformWeights, AreaWeights]) def test_register_weights(mocker, graph_with_nodes: HeteroData, attr_class): """Test ZarrNodes register correctly the weights.""" - mocker.patch.object(nodes, "open_dataset", return_value=None) - node_builder = nodes.ZarrNodes("dataset.zarr") + mocker.patch.object(builder, "open_dataset", return_value=None) + node_builder = builder.ZarrNodes("dataset.zarr") config = {"test_attr": {"_target_": f"anemoi.graphs.nodes.weights.{attr_class.__name__}"}} graph = node_builder.register_attributes(graph_with_nodes, "test_nodes", config) diff --git a/tests/test_graphs.py b/tests/test_graphs.py index 846ee89..0ceb171 100644 --- a/tests/test_graphs.py +++ b/tests/test_graphs.py @@ -6,9 +6,23 @@ # nor does it submit to any jurisdiction. -def test_graphs(): - pass +from pathlib import Path +import torch +from torch_geometric.data import HeteroData -if __name__ == "__main__": - test_graphs() +from anemoi.graphs import create + + +def test_graphs(config_file: tuple[Path, str], mock_grids_path: tuple[str, int]): + """Test GraphCreator workflow.""" + tmp_path, config_name = config_file + graph_path = tmp_path / "graph.pt" + config_path = tmp_path / config_name + + create.GraphCreator(graph_path, config_path).create() + + graph = torch.load(graph_path) + assert isinstance(graph, HeteroData) + assert "test_nodes" in graph.node_types + assert ("test_nodes", "to", "test_nodes") in graph.edge_types From 66ef5dc96fcb484c19a36139e57ae7416052dc91 Mon Sep 17 00:00:00 2001 From: Mario Santa Cruz Date: Thu, 27 Jun 2024 14:59:30 +0000 Subject: [PATCH 016/156] deps: == to >= --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 5695764..cb5bb7f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -52,7 +52,7 @@ dynamic = [ dependencies = [ "anemoi-datasets[data]>=0.3.3", "anemoi-utils>=0.3.6", - "hydra-core==1.3", + "hydra-core>=1.3", "torch>=2.2", "torch-geometric>=2.3.1,<2.5", ] From fcd0b74edc0743a0b62b226697bd17bbf6d97e7e Mon Sep 17 00:00:00 2001 From: Mario Santa Cruz Date: Thu, 27 Jun 2024 15:16:50 +0000 Subject: [PATCH 017/156] rename node builder classes --- src/anemoi/graphs/nodes/__init__.py | 6 +++--- src/anemoi/graphs/nodes/builder.py | 4 ++-- tests/conftest.py | 2 +- tests/nodes/test_npz.py | 14 +++++++------- tests/nodes/test_zarr.py | 10 +++++----- 5 files changed, 18 insertions(+), 18 deletions(-) diff --git a/src/anemoi/graphs/nodes/__init__.py b/src/anemoi/graphs/nodes/__init__.py index beecc98..f6d8e4d 100644 --- a/src/anemoi/graphs/nodes/__init__.py +++ b/src/anemoi/graphs/nodes/__init__.py @@ -1,4 +1,4 @@ -from .builder import NPZNodes -from .builder import ZarrNodes +from .builder import NPZFileNodeBuilder +from .builder import ZarrDatasetNodeBuilder -__all__ = ["ZarrNodes", "NPZNodes"] +__all__ = ["ZarrDatasetNodeBuilder", "NPZFileNodeBuilder"] diff --git a/src/anemoi/graphs/nodes/builder.py b/src/anemoi/graphs/nodes/builder.py index 886125e..774a222 100644 --- a/src/anemoi/graphs/nodes/builder.py +++ b/src/anemoi/graphs/nodes/builder.py @@ -40,7 +40,7 @@ def transform(self, graph: HeteroData, name: str, attr_config: DotDict) -> Heter return graph -class ZarrNodes(BaseNodeBuilder): +class ZarrDatasetNodeBuilder(BaseNodeBuilder): """Nodes from Zarr dataset.""" def __init__(self, dataset: DotDict) -> None: @@ -51,7 +51,7 @@ def get_coordinates(self) -> torch.Tensor: return self.reshape_coords(self.ds.latitudes, self.ds.longitudes) -class NPZNodes(BaseNodeBuilder): +class NPZFileNodeBuilder(BaseNodeBuilder): """Nodes from NPZ defined grids.""" def __init__(self, resolution: str, grid_definition_path: str) -> None: diff --git a/tests/conftest.py b/tests/conftest.py index b6b8ba0..4fffb34 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -60,7 +60,7 @@ def config_file(tmp_path) -> tuple[str, str]: "nodes": { "test_nodes": { "node_builder": { - "_target_": "anemoi.graphs.nodes.NPZNodes", + "_target_": "anemoi.graphs.nodes.NPZFileNodeBuilder", "grid_definition_path": str(tmp_path), "resolution": "o16", }, diff --git a/tests/nodes/test_npz.py b/tests/nodes/test_npz.py index ebe88d9..906e6d1 100644 --- a/tests/nodes/test_npz.py +++ b/tests/nodes/test_npz.py @@ -2,7 +2,7 @@ import torch from torch_geometric.data import HeteroData -from anemoi.graphs.nodes.builder import NPZNodes +from anemoi.graphs.nodes.builder import NPZFileNodeBuilder from anemoi.graphs.nodes.weights import AreaWeights from anemoi.graphs.nodes.weights import UniformWeights @@ -11,8 +11,8 @@ def test_init(mock_grids_path: tuple[str, int], resolution: str): """Test NPZNodes initialization.""" grid_definition_path, _ = mock_grids_path - node_builder = NPZNodes(resolution, grid_definition_path=grid_definition_path) - assert isinstance(node_builder, NPZNodes) + node_builder = NPZFileNodeBuilder(resolution, grid_definition_path=grid_definition_path) + assert isinstance(node_builder, NPZFileNodeBuilder) @pytest.mark.parametrize("resolution", ["o17", 13, "ajsnb", None]) @@ -20,13 +20,13 @@ def test_fail_init_wrong_resolution(mock_grids_path: tuple[str, int], resolution """Test NPZNodes initialization with invalid resolution.""" grid_definition_path, _ = mock_grids_path with pytest.raises(FileNotFoundError): - NPZNodes(resolution, grid_definition_path=grid_definition_path) + NPZFileNodeBuilder(resolution, grid_definition_path=grid_definition_path) def test_fail_init_wrong_path(): """Test NPZNodes initialization with invalid path.""" with pytest.raises(FileNotFoundError): - NPZNodes("o16", "invalid_path") + NPZFileNodeBuilder("o16", "invalid_path") @pytest.mark.parametrize("resolution", ["o16", "o48", "5km5"]) @@ -34,7 +34,7 @@ def test_register_nodes(mock_grids_path: str, resolution: str): """Test NPZNodes register correctly the nodes.""" graph = HeteroData() grid_definition_path, num_nodes = mock_grids_path - node_builder = NPZNodes(resolution, grid_definition_path=grid_definition_path) + node_builder = NPZFileNodeBuilder(resolution, grid_definition_path=grid_definition_path) graph = node_builder.register_nodes(graph, "test_nodes") @@ -48,7 +48,7 @@ def test_register_nodes(mock_grids_path: str, resolution: str): def test_register_weights(graph_with_nodes: HeteroData, mock_grids_path: tuple[str, int], attr_class): """Test NPZNodes register correctly the weights.""" grid_definition_path, _ = mock_grids_path - node_builder = NPZNodes("o16", grid_definition_path=grid_definition_path) + node_builder = NPZFileNodeBuilder("o16", grid_definition_path=grid_definition_path) config = {"test_attr": {"_target_": f"anemoi.graphs.nodes.weights.{attr_class.__name__}"}} graph = node_builder.register_attributes(graph_with_nodes, "test_nodes", config) diff --git a/tests/nodes/test_zarr.py b/tests/nodes/test_zarr.py index ddf804f..66a9d1a 100644 --- a/tests/nodes/test_zarr.py +++ b/tests/nodes/test_zarr.py @@ -11,21 +11,21 @@ def test_init(mocker, mock_zarr_dataset): """Test ZarrNodes initialization.""" mocker.patch.object(builder, "open_dataset", return_value=mock_zarr_dataset) - node_builder = builder.ZarrNodes("dataset.zarr") + node_builder = builder.ZarrDatasetNodeBuilder("dataset.zarr") assert isinstance(node_builder, builder.BaseNodeBuilder) - assert isinstance(node_builder, builder.ZarrNodes) + assert isinstance(node_builder, builder.ZarrDatasetNodeBuilder) def test_fail_init(): """Test ZarrNodes initialization with invalid resolution.""" with pytest.raises(zarr.errors.PathNotFoundError): - builder.ZarrNodes("invalid_path.zarr") + builder.ZarrDatasetNodeBuilder("invalid_path.zarr") def test_register_nodes(mocker, mock_zarr_dataset): """Test ZarrNodes register correctly the nodes.""" mocker.patch.object(builder, "open_dataset", return_value=mock_zarr_dataset) - node_builder = builder.ZarrNodes("dataset.zarr") + node_builder = builder.ZarrDatasetNodeBuilder("dataset.zarr") graph = HeteroData() graph = node_builder.register_nodes(graph, "test_nodes") @@ -40,7 +40,7 @@ def test_register_nodes(mocker, mock_zarr_dataset): def test_register_weights(mocker, graph_with_nodes: HeteroData, attr_class): """Test ZarrNodes register correctly the weights.""" mocker.patch.object(builder, "open_dataset", return_value=None) - node_builder = builder.ZarrNodes("dataset.zarr") + node_builder = builder.ZarrDatasetNodeBuilder("dataset.zarr") config = {"test_attr": {"_target_": f"anemoi.graphs.nodes.weights.{attr_class.__name__}"}} graph = node_builder.register_attributes(graph_with_nodes, "test_nodes", config) From b28b0ff0751743d0861955dd2a9a074da448de0b Mon Sep 17 00:00:00 2001 From: Mario Santa Cruz Date: Thu, 27 Jun 2024 15:42:37 +0000 Subject: [PATCH 018/156] fix: tests --- tests/nodes/test_npz.py | 2 +- tests/nodes/test_zarr.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/nodes/test_npz.py b/tests/nodes/test_npz.py index 906e6d1..61a5907 100644 --- a/tests/nodes/test_npz.py +++ b/tests/nodes/test_npz.py @@ -41,7 +41,7 @@ def test_register_nodes(mock_grids_path: str, resolution: str): assert graph["test_nodes"].x is not None assert isinstance(graph["test_nodes"].x, torch.Tensor) assert graph["test_nodes"].x.shape == (num_nodes, 2) - assert graph["test_nodes"].node_type == "NPZNodes" + assert graph["test_nodes"].node_type == "NPZFileNodeBuilder" @pytest.mark.parametrize("attr_class", [UniformWeights, AreaWeights]) diff --git a/tests/nodes/test_zarr.py b/tests/nodes/test_zarr.py index 66a9d1a..190e207 100644 --- a/tests/nodes/test_zarr.py +++ b/tests/nodes/test_zarr.py @@ -33,7 +33,7 @@ def test_register_nodes(mocker, mock_zarr_dataset): assert graph["test_nodes"].x is not None assert isinstance(graph["test_nodes"].x, torch.Tensor) assert graph["test_nodes"].x.shape == (node_builder.ds.num_nodes, 2) - assert graph["test_nodes"].node_type == "ZarrNodes" + assert graph["test_nodes"].node_type == "ZarrDatasetNodeBuilder" @pytest.mark.parametrize("attr_class", [UniformWeights, AreaWeights]) From ec8c9c5bf6c1b796f1f3176b57f40bbff9a684b1 Mon Sep 17 00:00:00 2001 From: Mario Santa Cruz Date: Fri, 28 Jun 2024 14:52:01 +0000 Subject: [PATCH 019/156] feat: support path and dict for `config` argument --- src/anemoi/graphs/create.py | 17 +++++++++-------- 1 file changed, 9 insertions(+), 8 deletions(-) diff --git a/src/anemoi/graphs/create.py b/src/anemoi/graphs/create.py index 53861a9..14e4a12 100644 --- a/src/anemoi/graphs/create.py +++ b/src/anemoi/graphs/create.py @@ -21,26 +21,26 @@ def __init__( overwrite=False, **kwargs, ): + if isinstance(config, str) or isinstance(config, os.PathLike): + self.config = DotDict.from_file(self.config) + else: + self.config = config + self.path = path # Output path - self.config = config self.cache = cache self.print = print self.overwrite = overwrite def init(self): - assert os.path.exists(self.config), f"Path {self.config} does not exist." - if self._path_readable() and not self.overwrite: raise Exception(f"{self.path} already exists. Use overwrite=True to overwrite.") def generate_graph(self) -> HeteroData: - config = DotDict.from_file(self.config) - graph = HeteroData() - for name, nodes_cfg in config.nodes.items(): + for name, nodes_cfg in self.config.nodes.items(): graph = instantiate(nodes_cfg.node_builder).transform(graph, name, nodes_cfg.get("attributes", {})) - for edges_cfg in config.edges: + for edges_cfg in self.config.edges: graph = instantiate(edges_cfg.edge_builder, **edges_cfg.nodes).transform( graph, edges_cfg.get("attributes", {}) ) @@ -52,10 +52,11 @@ def save(self, graph: HeteroData) -> None: torch.save(graph, self.path) self.print(f"Graph saved at {self.path}.") - def create(self): + def create(self) -> HeteroData: self.init() graph = self.generate_graph() self.save(graph) + return graph def _path_readable(self) -> bool: import torch From 9b9d805b6127224be0c998aec53234f47d3d7922 Mon Sep 17 00:00:00 2001 From: Mario Santa Cruz Date: Fri, 28 Jun 2024 15:15:41 +0000 Subject: [PATCH 020/156] fix: error --- src/anemoi/graphs/create.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/anemoi/graphs/create.py b/src/anemoi/graphs/create.py index 14e4a12..57b60bc 100644 --- a/src/anemoi/graphs/create.py +++ b/src/anemoi/graphs/create.py @@ -22,7 +22,7 @@ def __init__( **kwargs, ): if isinstance(config, str) or isinstance(config, os.PathLike): - self.config = DotDict.from_file(self.config) + self.config = DotDict.from_file(config) else: self.config = config From 2b67bf3388b7915a8d28eb654be9eedfdfce0756 Mon Sep 17 00:00:00 2001 From: theissenhelen Date: Mon, 1 Jul 2024 10:12:40 +0000 Subject: [PATCH 021/156] refactor: naming --- src/anemoi/graphs/edges/__init__.py | 7 +++--- src/anemoi/graphs/edges/builder.py | 4 ++-- src/anemoi/graphs/nodes/__init__.py | 6 ++--- .../nodes/{weights.py => attributes.py} | 0 src/anemoi/graphs/nodes/builder.py | 4 ++-- tests/conftest.py | 4 ++-- tests/edges/test_cutoff.py | 8 +++---- tests/edges/test_knn.py | 8 +++---- ...est_weights.py => test_node_attributes.py} | 4 ++-- tests/nodes/test_npz.py | 24 +++++++++---------- tests/nodes/test_zarr.py | 20 ++++++++-------- 11 files changed, 45 insertions(+), 44 deletions(-) rename src/anemoi/graphs/nodes/{weights.py => attributes.py} (100%) rename tests/nodes/{test_weights.py => test_node_attributes.py} (93%) diff --git a/src/anemoi/graphs/edges/__init__.py b/src/anemoi/graphs/edges/__init__.py index edd07db..19d48db 100644 --- a/src/anemoi/graphs/edges/__init__.py +++ b/src/anemoi/graphs/edges/__init__.py @@ -1,4 +1,5 @@ -from .builder import CutOffEdgeBuilder -from .builder import KNNEdgeBuilder +from .builder import CutOffEdges +from .builder import KNNEdges + +__all__ = ["KNNEdges", "CutOffEdges"] -__all__ = ["KNNEdgeBuilder", "CutOffEdgeBuilder"] diff --git a/src/anemoi/graphs/edges/builder.py b/src/anemoi/graphs/edges/builder.py index 49080ce..759502d 100644 --- a/src/anemoi/graphs/edges/builder.py +++ b/src/anemoi/graphs/edges/builder.py @@ -64,7 +64,7 @@ def transform(self, graph: HeteroData, attrs_config: Optional[DotDict] = None) - return graph -class KNNEdgeBuilder(BaseEdgeBuilder): +class KNNEdges(BaseEdgeBuilder): """Computes KNN based edges and adds them to the graph.""" def __init__(self, src_name: str, dst_name: str, num_nearest_neighbours: int): @@ -92,7 +92,7 @@ def get_adj_matrix(self, src_nodes: np.ndarray, dst_nodes: np.ndarray): return adj_matrix -class CutOffEdgeBuilder(BaseEdgeBuilder): +class CutOffEdges(BaseEdgeBuilder): """Computes cut-off based edges and adds them to the graph.""" def __init__(self, src_name: str, dst_name: str, cutoff_factor: float): diff --git a/src/anemoi/graphs/nodes/__init__.py b/src/anemoi/graphs/nodes/__init__.py index f6d8e4d..737f27f 100644 --- a/src/anemoi/graphs/nodes/__init__.py +++ b/src/anemoi/graphs/nodes/__init__.py @@ -1,4 +1,4 @@ -from .builder import NPZFileNodeBuilder -from .builder import ZarrDatasetNodeBuilder +from .builder import NPZFileNodes +from .builder import ZarrDatasetNodes -__all__ = ["ZarrDatasetNodeBuilder", "NPZFileNodeBuilder"] +__all__ = ["ZarrDatasetNodes", "NPZFileNodes"] diff --git a/src/anemoi/graphs/nodes/weights.py b/src/anemoi/graphs/nodes/attributes.py similarity index 100% rename from src/anemoi/graphs/nodes/weights.py rename to src/anemoi/graphs/nodes/attributes.py diff --git a/src/anemoi/graphs/nodes/builder.py b/src/anemoi/graphs/nodes/builder.py index 774a222..7c769cc 100644 --- a/src/anemoi/graphs/nodes/builder.py +++ b/src/anemoi/graphs/nodes/builder.py @@ -40,7 +40,7 @@ def transform(self, graph: HeteroData, name: str, attr_config: DotDict) -> Heter return graph -class ZarrDatasetNodeBuilder(BaseNodeBuilder): +class ZarrDatasetNodes(BaseNodeBuilder): """Nodes from Zarr dataset.""" def __init__(self, dataset: DotDict) -> None: @@ -51,7 +51,7 @@ def get_coordinates(self) -> torch.Tensor: return self.reshape_coords(self.ds.latitudes, self.ds.longitudes) -class NPZFileNodeBuilder(BaseNodeBuilder): +class NPZFileNodes(BaseNodeBuilder): """Nodes from NPZ defined grids.""" def __init__(self, resolution: str, grid_definition_path: str) -> None: diff --git a/tests/conftest.py b/tests/conftest.py index 4fffb34..c074212 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -60,7 +60,7 @@ def config_file(tmp_path) -> tuple[str, str]: "nodes": { "test_nodes": { "node_builder": { - "_target_": "anemoi.graphs.nodes.NPZFileNodeBuilder", + "_target_": "anemoi.graphs.nodes.NPZFileNodes", "grid_definition_path": str(tmp_path), "resolution": "o16", }, @@ -70,7 +70,7 @@ def config_file(tmp_path) -> tuple[str, str]: { "nodes": {"src_name": "test_nodes", "dst_name": "test_nodes"}, "edge_builder": { - "_target_": "anemoi.graphs.edges.KNNEdgeBuilder", + "_target_": "anemoi.graphs.edges.KNNEdges", "num_nearest_neighbours": 3, }, "attributes": { diff --git a/tests/edges/test_cutoff.py b/tests/edges/test_cutoff.py index 844cf87..efe4ee8 100644 --- a/tests/edges/test_cutoff.py +++ b/tests/edges/test_cutoff.py @@ -1,22 +1,22 @@ import pytest -from anemoi.graphs.edges import CutOffEdgeBuilder +from anemoi.graphs.edges import CutOffEdges def test_init(): """Test CutOffEdgeBuilder initialization.""" - CutOffEdgeBuilder("test_nodes1", "test_nodes2", 0.5) + CutOffEdges("test_nodes1", "test_nodes2", 0.5) @pytest.mark.parametrize("cutoff_factor", [-0.5, "hello", None]) def test_fail_init(cutoff_factor: str): """Test CutOffEdgeBuilder initialization with invalid cutoff.""" with pytest.raises(AssertionError): - CutOffEdgeBuilder("test_nodes1", "test_nodes2", cutoff_factor) + CutOffEdges("test_nodes1", "test_nodes2", cutoff_factor) def test_cutoff(graph_with_nodes): """Test CutOffEdgeBuilder.""" - builder = CutOffEdgeBuilder("test_nodes", "test_nodes", 0.5) + builder = CutOffEdges("test_nodes", "test_nodes", 0.5) graph = builder.transform(graph_with_nodes) assert ("test_nodes", "to", "test_nodes") in graph.edge_types diff --git a/tests/edges/test_knn.py b/tests/edges/test_knn.py index aee529a..7149d0e 100644 --- a/tests/edges/test_knn.py +++ b/tests/edges/test_knn.py @@ -1,22 +1,22 @@ import pytest -from anemoi.graphs.edges import KNNEdgeBuilder +from anemoi.graphs.edges import KNNEdges def test_init(): """Test CutOffEdgeBuilder initialization.""" - KNNEdgeBuilder("test_nodes1", "test_nodes2", 3) + KNNEdges("test_nodes1", "test_nodes2", 3) @pytest.mark.parametrize("num_nearest_neighbours", [-1, 2.6, "hello", None]) def test_fail_init(num_nearest_neighbours: str): """Test KNNEdgeBuilder initialization with invalid number of nearest neighbours.""" with pytest.raises(AssertionError): - KNNEdgeBuilder("test_nodes1", "test_nodes2", num_nearest_neighbours) + KNNEdges("test_nodes1", "test_nodes2", num_nearest_neighbours) def test_knn(graph_with_nodes): """Test KNNEdgeBuilder.""" - builder = KNNEdgeBuilder("test_nodes", "test_nodes", 3) + builder = KNNEdges("test_nodes", "test_nodes", 3) graph = builder.transform(graph_with_nodes) assert ("test_nodes", "to", "test_nodes") in graph.edge_types diff --git a/tests/nodes/test_weights.py b/tests/nodes/test_node_attributes.py similarity index 93% rename from tests/nodes/test_weights.py rename to tests/nodes/test_node_attributes.py index 71e54fa..d5ccf89 100644 --- a/tests/nodes/test_weights.py +++ b/tests/nodes/test_node_attributes.py @@ -2,8 +2,8 @@ import torch from torch_geometric.data import HeteroData -from anemoi.graphs.nodes.weights import AreaWeights -from anemoi.graphs.nodes.weights import UniformWeights +from anemoi.graphs.nodes.attributes import AreaWeights +from anemoi.graphs.nodes.attributes import UniformWeights @pytest.mark.parametrize("norm", [None, "l1", "l2", "unit-max", "unit-sum", "unit-std"]) diff --git a/tests/nodes/test_npz.py b/tests/nodes/test_npz.py index 61a5907..7220f40 100644 --- a/tests/nodes/test_npz.py +++ b/tests/nodes/test_npz.py @@ -2,17 +2,17 @@ import torch from torch_geometric.data import HeteroData -from anemoi.graphs.nodes.builder import NPZFileNodeBuilder -from anemoi.graphs.nodes.weights import AreaWeights -from anemoi.graphs.nodes.weights import UniformWeights +from anemoi.graphs.nodes.builder import NPZFileNodes +from anemoi.graphs.nodes.attributes import AreaWeights +from anemoi.graphs.nodes.attributes import UniformWeights @pytest.mark.parametrize("resolution", ["o16", "o48", "5km5"]) def test_init(mock_grids_path: tuple[str, int], resolution: str): """Test NPZNodes initialization.""" grid_definition_path, _ = mock_grids_path - node_builder = NPZFileNodeBuilder(resolution, grid_definition_path=grid_definition_path) - assert isinstance(node_builder, NPZFileNodeBuilder) + node_builder = NPZFileNodes(resolution, grid_definition_path=grid_definition_path) + assert isinstance(node_builder, NPZFileNodes) @pytest.mark.parametrize("resolution", ["o17", 13, "ajsnb", None]) @@ -20,13 +20,13 @@ def test_fail_init_wrong_resolution(mock_grids_path: tuple[str, int], resolution """Test NPZNodes initialization with invalid resolution.""" grid_definition_path, _ = mock_grids_path with pytest.raises(FileNotFoundError): - NPZFileNodeBuilder(resolution, grid_definition_path=grid_definition_path) + NPZFileNodes(resolution, grid_definition_path=grid_definition_path) def test_fail_init_wrong_path(): """Test NPZNodes initialization with invalid path.""" with pytest.raises(FileNotFoundError): - NPZFileNodeBuilder("o16", "invalid_path") + NPZFileNodes("o16", "invalid_path") @pytest.mark.parametrize("resolution", ["o16", "o48", "5km5"]) @@ -34,22 +34,22 @@ def test_register_nodes(mock_grids_path: str, resolution: str): """Test NPZNodes register correctly the nodes.""" graph = HeteroData() grid_definition_path, num_nodes = mock_grids_path - node_builder = NPZFileNodeBuilder(resolution, grid_definition_path=grid_definition_path) + node_builder = NPZFileNodes(resolution, grid_definition_path=grid_definition_path) graph = node_builder.register_nodes(graph, "test_nodes") assert graph["test_nodes"].x is not None assert isinstance(graph["test_nodes"].x, torch.Tensor) assert graph["test_nodes"].x.shape == (num_nodes, 2) - assert graph["test_nodes"].node_type == "NPZFileNodeBuilder" + assert graph["test_nodes"].node_type == "NPZFileNodes" @pytest.mark.parametrize("attr_class", [UniformWeights, AreaWeights]) -def test_register_weights(graph_with_nodes: HeteroData, mock_grids_path: tuple[str, int], attr_class): +def test_register_attributes(graph_with_nodes: HeteroData, mock_grids_path: tuple[str, int], attr_class): """Test NPZNodes register correctly the weights.""" grid_definition_path, _ = mock_grids_path - node_builder = NPZFileNodeBuilder("o16", grid_definition_path=grid_definition_path) - config = {"test_attr": {"_target_": f"anemoi.graphs.nodes.weights.{attr_class.__name__}"}} + node_builder = NPZFileNodes("o16", grid_definition_path=grid_definition_path) + config = {"test_attr": {"_target_": f"anemoi.graphs.nodes.attributes.{attr_class.__name__}"}} graph = node_builder.register_attributes(graph_with_nodes, "test_nodes", config) diff --git a/tests/nodes/test_zarr.py b/tests/nodes/test_zarr.py index 190e207..0e91ece 100644 --- a/tests/nodes/test_zarr.py +++ b/tests/nodes/test_zarr.py @@ -4,28 +4,28 @@ from torch_geometric.data import HeteroData from anemoi.graphs.nodes import builder -from anemoi.graphs.nodes.weights import AreaWeights -from anemoi.graphs.nodes.weights import UniformWeights +from anemoi.graphs.nodes.attributes import AreaWeights +from anemoi.graphs.nodes.attributes import UniformWeights def test_init(mocker, mock_zarr_dataset): """Test ZarrNodes initialization.""" mocker.patch.object(builder, "open_dataset", return_value=mock_zarr_dataset) - node_builder = builder.ZarrDatasetNodeBuilder("dataset.zarr") + node_builder = builder.ZarrDatasetNodes("dataset.zarr") assert isinstance(node_builder, builder.BaseNodeBuilder) - assert isinstance(node_builder, builder.ZarrDatasetNodeBuilder) + assert isinstance(node_builder, builder.ZarrDatasetNodes) def test_fail_init(): """Test ZarrNodes initialization with invalid resolution.""" with pytest.raises(zarr.errors.PathNotFoundError): - builder.ZarrDatasetNodeBuilder("invalid_path.zarr") + builder.ZarrDatasetNodes("invalid_path.zarr") def test_register_nodes(mocker, mock_zarr_dataset): """Test ZarrNodes register correctly the nodes.""" mocker.patch.object(builder, "open_dataset", return_value=mock_zarr_dataset) - node_builder = builder.ZarrDatasetNodeBuilder("dataset.zarr") + node_builder = builder.ZarrDatasetNodes("dataset.zarr") graph = HeteroData() graph = node_builder.register_nodes(graph, "test_nodes") @@ -33,15 +33,15 @@ def test_register_nodes(mocker, mock_zarr_dataset): assert graph["test_nodes"].x is not None assert isinstance(graph["test_nodes"].x, torch.Tensor) assert graph["test_nodes"].x.shape == (node_builder.ds.num_nodes, 2) - assert graph["test_nodes"].node_type == "ZarrDatasetNodeBuilder" + assert graph["test_nodes"].node_type == "ZarrDatasetNodes" @pytest.mark.parametrize("attr_class", [UniformWeights, AreaWeights]) -def test_register_weights(mocker, graph_with_nodes: HeteroData, attr_class): +def test_register_attributes(mocker, graph_with_nodes: HeteroData, attr_class): """Test ZarrNodes register correctly the weights.""" mocker.patch.object(builder, "open_dataset", return_value=None) - node_builder = builder.ZarrDatasetNodeBuilder("dataset.zarr") - config = {"test_attr": {"_target_": f"anemoi.graphs.nodes.weights.{attr_class.__name__}"}} + node_builder = builder.ZarrDatasetNodes("dataset.zarr") + config = {"test_attr": {"_target_": f"anemoi.graphs.nodes.attributes.{attr_class.__name__}"}} graph = node_builder.register_attributes(graph_with_nodes, "test_nodes", config) From cdeaa03872135ddef43fa4cd75ae2d1e2dba4eb6 Mon Sep 17 00:00:00 2001 From: theissenhelen Date: Mon, 1 Jul 2024 10:19:05 +0000 Subject: [PATCH 022/156] fix: pre-commit --- src/anemoi/graphs/edges/__init__.py | 1 - tests/nodes/test_npz.py | 2 +- 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/src/anemoi/graphs/edges/__init__.py b/src/anemoi/graphs/edges/__init__.py index 19d48db..53b9c74 100644 --- a/src/anemoi/graphs/edges/__init__.py +++ b/src/anemoi/graphs/edges/__init__.py @@ -2,4 +2,3 @@ from .builder import KNNEdges __all__ = ["KNNEdges", "CutOffEdges"] - diff --git a/tests/nodes/test_npz.py b/tests/nodes/test_npz.py index 7220f40..fc4cf8c 100644 --- a/tests/nodes/test_npz.py +++ b/tests/nodes/test_npz.py @@ -2,9 +2,9 @@ import torch from torch_geometric.data import HeteroData -from anemoi.graphs.nodes.builder import NPZFileNodes from anemoi.graphs.nodes.attributes import AreaWeights from anemoi.graphs.nodes.attributes import UniformWeights +from anemoi.graphs.nodes.builder import NPZFileNodes @pytest.mark.parametrize("resolution", ["o16", "o48", "5km5"]) From f07434c8d1e9fc98a9623d3138f410135fe64dc5 Mon Sep 17 00:00:00 2001 From: theissenhelen Date: Fri, 28 Jun 2024 12:06:22 +0000 Subject: [PATCH 023/156] feat: builders icosahedral --- src/anemoi/graphs/nodes/builder.py | 52 ++++++++++++++++++++++++++++++ 1 file changed, 52 insertions(+) diff --git a/src/anemoi/graphs/nodes/builder.py b/src/anemoi/graphs/nodes/builder.py index 7c769cc..aa15b6f 100644 --- a/src/anemoi/graphs/nodes/builder.py +++ b/src/anemoi/graphs/nodes/builder.py @@ -16,6 +16,9 @@ class BaseNodeBuilder(ABC): """Base class for node builders.""" + def __init__(self) -> None: + self.aoi_mask_builder = None + def register_nodes(self, graph: HeteroData, name: str) -> None: graph[name].x = self.get_coordinates() graph[name].node_type = type(self).__name__ @@ -62,3 +65,52 @@ def __init__(self, resolution: str, grid_definition_path: str) -> None: def get_coordinates(self) -> np.ndarray: coords = self.reshape_coords(self.grid_definition["latitudes"], self.grid_definition["longitudes"]) return coords + + +class RefinedIcosahedralNodeBuilder(BaseNodeBuilder): + """Processor mesh based on a triangular mesh. + + It is based on the icosahedral mesh, which is a mesh of triangles that covers the sphere. + + Parameters + ---------- + resolution : list[int] | int + Refinement level of the mesh. + np_dtype : np.dtype, optional + The numpy data type to use, by default np.float32. + """ + + def __init__( + self, + resolution: Union[int, list[int]], + np_dtype: np.dtype = np.float32, + ) -> None: + self.np_dtype = np_dtype + + if isinstance(resolution, int): + self.resolutions = list(range(resolution + 1)) + else: + self.resolutions = resolution + + super().__init__() + + def get_coordinates(self) -> np.ndarray: + self.nx_graph, coords_rad, self.node_ordering = self.create_nodes() + return coords_rad[self.node_ordering] + + def create_nodes(self) -> np.ndarray: ... + + def register_attributes(self, graph: HeteroData, name: str, config: DotDict) -> HeteroData: + graph[name]["resolutions"] = self.resolutions + graph[name]["nx_graph"] = self.nx_graph + graph[name]["node_ordering"] = self.node_ordering + graph[name]["aoi_mask_builder"] = self.aoi_mask_builder + return super().register_attributes(graph, name, config) + + +class TriRefinedIcosahedralNodeBuilder(RefinedIcosahedralNodeBuilder): + """It depends on the trimesh Python library.""" + + def create_nodes(self) -> np.ndarray: + return create_icosahedral_nodes(resolutions=self.resolutions, aoi_nneighb=self.aoi_mask_builder) + From 52403d79897c6f57c59da3dff816f1bc7ac8f0c4 Mon Sep 17 00:00:00 2001 From: theissenhelen Date: Fri, 28 Jun 2024 15:30:13 +0000 Subject: [PATCH 024/156] feat: Add icosahedral graph generation Co-authored-by: Mario Santa Cruz --- src/anemoi/graphs/generate/icosahedral.py | 216 ++++++++++++++++++++++ src/anemoi/graphs/nodes/builder.py | 10 +- 2 files changed, 221 insertions(+), 5 deletions(-) create mode 100644 src/anemoi/graphs/generate/icosahedral.py diff --git a/src/anemoi/graphs/generate/icosahedral.py b/src/anemoi/graphs/generate/icosahedral.py new file mode 100644 index 0000000..d28d7ef --- /dev/null +++ b/src/anemoi/graphs/generate/icosahedral.py @@ -0,0 +1,216 @@ +from collections.abc import Iterable +from typing import Optional + +import networkx as nx +import numpy as np +import trimesh +from sklearn.metrics.pairwise import haversine_distances +from sklearn.neighbors import BallTree + +from anemoi.graphs.generate.transforms import cartesian_to_latlon_rad +import logging + +logger = logging.getLogger(__name__) + + +def create_icosahedral_nodes( + resolutions: list[int], +) -> tuple[nx.DiGraph, np.ndarray, list[int]]: + """Creates a global mesh following AIFS strategy. + + This method relies on the trimesh python library. + + Parameters + ---------- + resolutions : list[int] + Levels of mesh resolution to consider. + aoi_mask_builder : KNNAreaMaskBuilder + KNNAreaMaskBuilder with the cloud of points to limit the mesh area, by default None. + + Returns + ------- + graph : networkx.Graph + The specified graph (nodes & edges). + vertices_rad : np.ndarray + The vertices (not ordered) of the mesh in radians. + node_ordering : list[int] + Order of the nodes in the graph to be sorted by latitude and longitude. + """ + sphere = create_sphere(resolutions[-1]) + coords_rad = cartesian_to_latlon_rad(sphere.vertices) + + node_ordering = get_node_ordering(coords_rad) + + # TODO: AOI mask builder is not used in the current implementation. + + nx_graph = create_icosahedral_nx_graph_from_coords(coords_rad, node_ordering) + + return nx_graph, coords_rad, list(node_ordering) + + +def create_icosahedral_nx_graph_from_coords(coords_rad: np.ndarray, node_ordering: list[int]): + + graph = nx.DiGraph() + for ii, coords in enumerate(coords_rad[node_ordering]): + node_id = node_ordering[ii] + graph.add_node(node_id, hcoords_rad=coords) + + assert list(graph.nodes.keys()) == list(node_ordering), "Nodes are not correctly added to the graph." + assert graph.number_of_nodes() == len(node_ordering), "The number of nodes must be the same." + return graph + + +def get_node_ordering(vertices_rad: np.ndarray) -> np.ndarray: + # Get indices to sort points by lon & lat in radians. + ind1 = np.argsort(vertices_rad[:, 1]) + ind2 = np.argsort(vertices_rad[ind1][:, 0])[::-1] + node_ordering = np.arange(vertices_rad.shape[0])[ind1][ind2] + return node_ordering + + +def add_edges_to_nx_graph( + graph: nx.DiGraph, + resolutions: list[int], + xhops: int = 1, +) -> None: + """Adds the edges to the graph. + + Parameters + ---------- + graph : nx.DiGraph + The graph to add the edges. It should correspond to the mesh nodes, without edges. + resolutions : list[int] + Levels of mesh refinement levels to consider. + xhops : int, optional + Number of hops between 2 nodes to consider them neighbours, by default 1. + aoi_mask_builder : KNNAreaMaskBuilder + NearestNeighbors with the cloud of points to limit the mesh area, by default None. + margin_radius_km : float, optional + Margin radius in km to consider when creating the processor mesh, by default 0.0. + """ + assert xhops > 0, "xhops == 0, graph would have no edges ..." + + sphere = create_sphere(resolutions[-1]) + vertices_rad = cartesian_to_latlon_rad(sphere.vertices) + x_hops = get_x_hops(sphere, xhops, valid_nodes=list(graph.nodes)) + + for i, i_neighbours in x_hops.items(): + add_neigbours_edges(graph, vertices_rad, i, i_neighbours) + + tree = BallTree(vertices_rad, metric="haversine") + + for resolution in resolutions[:-1]: + # Defined refined sphere + r_sphere = create_sphere(resolution) + r_vertices_rad = cartesian_to_latlon_rad(r_sphere.vertices) + + # TODO AOI mask builder is not used in the current implementation. + valid_nodes = None + + x_rings = get_x_hops(r_sphere, xhops, valid_nodes=valid_nodes) + + _, idx = tree.query(r_vertices_rad, k=1) + for i, i_neighbours in x_rings.items(): + add_neigbours_edges(graph, r_vertices_rad, i, i_neighbours, idx=idx) + + return graph + + +def create_sphere(subdivisions: int = 0, radius: float = 1.0) -> trimesh.Trimesh: + """Creates a sphere. + + Parameters + ---------- + subdivisions : int, optional + How many times to subdivide the mesh. Note that the number of faces will grow as function of 4 ** subdivisions. + Defaults to 0. + radius : float, optional + Radius of the sphere created, by default 1.0 + + Returns + ------- + trimesh.Trimesh + Meshed sphere. + """ + return trimesh.creation.icosphere(subdivisions=subdivisions, radius=radius) + + +def get_x_hops(sp: trimesh.Trimesh, hops: int, valid_nodes: Optional[list[int]] = None) -> dict[int, set[int]]: + """Get the neigbour connections in the graph. + + Parameters + ---------- + sp : trimesh.Trimesh + The mesh to consider. + hops : int + Number of hops between 2 nodes to consider them neighbours. + valid_nodes : list[int], optional + List of valid nodes to consider, by default None. It is useful to consider only a subset of the nodes to save + computation time. + + Returns + ------- + neighbours : dict[int, set[int]] + A list with the neighbours for each vertex. The element at position 'i' correspond to the neighbours to the + i-th vertex of the mesh. + """ + edges = sp.edges_unique + if valid_nodes is not None: + edges = edges[np.isin(sp.edges_unique, valid_nodes).all(axis=1)] + else: + valid_nodes = list(range(len(sp.vertices))) + g = nx.from_edgelist(edges) + + neighbours = {ii: set(nx.ego_graph(g, ii, radius=hops, center=False) if ii in g else []) for ii in valid_nodes} + + return neighbours + + +def add_neigbours_edges( + graph: nx.Graph, + vertices: np.ndarray, + ii: int, + neighbours: Iterable[int], + self_loops: bool = False, + idx: Optional[np.ndarray] = None, +) -> None: + """Adds the edges of one node to its neighbours. + + Parameters + ---------- + graph : nx.Graph + The graph. + vertices : np.ndarray + A 2D array of shape (num_vertices, 2) with the planar coordinates of the mesh, in radians. + ii : int + The node considered. + neighbours : list[int] + The neighbours of the node. + self_loops : bool, optional + Whether is supported to add self-loops, by default False. + idx : np.ndarray, optional + Index to map the vertices from the refined sphere to the original one, by default None. + """ + for ineighb in neighbours: + if not self_loops and ii == ineighb: # no self-loops + continue + + loc_self = vertices[ii] + loc_neigh = vertices[ineighb] + edge_length = haversine_distances([loc_neigh, loc_self])[0][1] + + if idx is not None: + # Use the same method to add edge in all spheres + node_neigh = idx[ineighb][0] + node = idx[ii][0] + else: + node, node_neigh = ii, ineighb + + # add edge to the graph + if node in graph and node_neigh in graph: + graph.add_edge(node_neigh, node, weight=edge_length) + + + + + diff --git a/src/anemoi/graphs/nodes/builder.py b/src/anemoi/graphs/nodes/builder.py index aa15b6f..b003fa0 100644 --- a/src/anemoi/graphs/nodes/builder.py +++ b/src/anemoi/graphs/nodes/builder.py @@ -2,6 +2,7 @@ from abc import ABC from abc import abstractmethod from pathlib import Path +from typing import Union import numpy as np import torch @@ -9,6 +10,7 @@ from anemoi.utils.config import DotDict from hydra.utils import instantiate from torch_geometric.data import HeteroData +from anemoi.graphs.generate.icosahedral import create_icosahedral_nodes logger = logging.getLogger(__name__) @@ -16,9 +18,6 @@ class BaseNodeBuilder(ABC): """Base class for node builders.""" - def __init__(self) -> None: - self.aoi_mask_builder = None - def register_nodes(self, graph: HeteroData, name: str) -> None: graph[name].x = self.get_coordinates() graph[name].node_type = type(self).__name__ @@ -104,7 +103,7 @@ def register_attributes(self, graph: HeteroData, name: str, config: DotDict) -> graph[name]["resolutions"] = self.resolutions graph[name]["nx_graph"] = self.nx_graph graph[name]["node_ordering"] = self.node_ordering - graph[name]["aoi_mask_builder"] = self.aoi_mask_builder + # TODO: AOI mask builder is not used in the current implementation. return super().register_attributes(graph, name, config) @@ -112,5 +111,6 @@ class TriRefinedIcosahedralNodeBuilder(RefinedIcosahedralNodeBuilder): """It depends on the trimesh Python library.""" def create_nodes(self) -> np.ndarray: - return create_icosahedral_nodes(resolutions=self.resolutions, aoi_nneighb=self.aoi_mask_builder) + # TODO: AOI mask builder is not used in the current implementation. + return create_icosahedral_nodes(resolutions=self.resolutions) From 2ef63c2c9dcc2dffc76692569b06f657509bd51b Mon Sep 17 00:00:00 2001 From: theissenhelen Date: Fri, 28 Jun 2024 15:38:13 +0000 Subject: [PATCH 025/156] refactor: remove create_shere --- src/anemoi/graphs/generate/icosahedral.py | 54 +++++++------------ src/anemoi/graphs/nodes/nodes.py | 65 +++++++++++++++++++++++ 2 files changed, 83 insertions(+), 36 deletions(-) create mode 100644 src/anemoi/graphs/nodes/nodes.py diff --git a/src/anemoi/graphs/generate/icosahedral.py b/src/anemoi/graphs/generate/icosahedral.py index d28d7ef..97c4801 100644 --- a/src/anemoi/graphs/generate/icosahedral.py +++ b/src/anemoi/graphs/generate/icosahedral.py @@ -36,7 +36,8 @@ def create_icosahedral_nodes( node_ordering : list[int] Order of the nodes in the graph to be sorted by latitude and longitude. """ - sphere = create_sphere(resolutions[-1]) + sphere = trimesh.creation.icosphere(subdivisions=resolutions[-1], radius=1.0) + coords_rad = cartesian_to_latlon_rad(sphere.vertices) node_ordering = get_node_ordering(coords_rad) @@ -90,7 +91,7 @@ def add_edges_to_nx_graph( """ assert xhops > 0, "xhops == 0, graph would have no edges ..." - sphere = create_sphere(resolutions[-1]) + sphere = trimesh.creation.icosphere(subdivisions=resolutions[-1], radius=1.0) vertices_rad = cartesian_to_latlon_rad(sphere.vertices) x_hops = get_x_hops(sphere, xhops, valid_nodes=list(graph.nodes)) @@ -101,7 +102,7 @@ def add_edges_to_nx_graph( for resolution in resolutions[:-1]: # Defined refined sphere - r_sphere = create_sphere(resolution) + r_sphere = trimesh.creation.icosphere(subdivisions=resolution, radius=1.0) r_vertices_rad = cartesian_to_latlon_rad(r_sphere.vertices) # TODO AOI mask builder is not used in the current implementation. @@ -116,31 +117,12 @@ def add_edges_to_nx_graph( return graph -def create_sphere(subdivisions: int = 0, radius: float = 1.0) -> trimesh.Trimesh: - """Creates a sphere. - - Parameters - ---------- - subdivisions : int, optional - How many times to subdivide the mesh. Note that the number of faces will grow as function of 4 ** subdivisions. - Defaults to 0. - radius : float, optional - Radius of the sphere created, by default 1.0 - - Returns - ------- - trimesh.Trimesh - Meshed sphere. - """ - return trimesh.creation.icosphere(subdivisions=subdivisions, radius=radius) - - -def get_x_hops(sp: trimesh.Trimesh, hops: int, valid_nodes: Optional[list[int]] = None) -> dict[int, set[int]]: +def get_x_hops(tri_mesh: trimesh.Trimesh, hops: int, valid_nodes: Optional[list[int]] = None) -> dict[int, set[int]]: """Get the neigbour connections in the graph. Parameters ---------- - sp : trimesh.Trimesh + tri_mesh : trimesh.Trimesh The mesh to consider. hops : int Number of hops between 2 nodes to consider them neighbours. @@ -154,11 +136,11 @@ def get_x_hops(sp: trimesh.Trimesh, hops: int, valid_nodes: Optional[list[int]] A list with the neighbours for each vertex. The element at position 'i' correspond to the neighbours to the i-th vertex of the mesh. """ - edges = sp.edges_unique + edges = tri_mesh.edges_unique if valid_nodes is not None: - edges = edges[np.isin(sp.edges_unique, valid_nodes).all(axis=1)] + edges = edges[np.isin(tri_mesh.edges_unique, valid_nodes).all(axis=1)] else: - valid_nodes = list(range(len(sp.vertices))) + valid_nodes = list(range(len(tri_mesh.vertices))) g = nx.from_edgelist(edges) neighbours = {ii: set(nx.ego_graph(g, ii, radius=hops, center=False) if ii in g else []) for ii in valid_nodes} @@ -191,24 +173,24 @@ def add_neigbours_edges( idx : np.ndarray, optional Index to map the vertices from the refined sphere to the original one, by default None. """ - for ineighb in neighbours: - if not self_loops and ii == ineighb: # no self-loops + for idx_neighbour in neighbours: + if not self_loops and ii == idx_neighbour: # no self-loops continue - loc_self = vertices[ii] - loc_neigh = vertices[ineighb] - edge_length = haversine_distances([loc_neigh, loc_self])[0][1] + location_node = vertices[ii] + location_neighbour = vertices[idx_neighbour] + edge_length = haversine_distances([location_neighbour, location_node])[0][1] if idx is not None: # Use the same method to add edge in all spheres - node_neigh = idx[ineighb][0] + node_neighbour = idx[idx_neighbour][0] node = idx[ii][0] else: - node, node_neigh = ii, ineighb + node, node_neighbour = ii, idx_neighbour # add edge to the graph - if node in graph and node_neigh in graph: - graph.add_edge(node_neigh, node, weight=edge_length) + if node in graph and node_neighbour in graph: + graph.add_edge(node_neighbour, node, weight=edge_length) diff --git a/src/anemoi/graphs/nodes/nodes.py b/src/anemoi/graphs/nodes/nodes.py new file mode 100644 index 0000000..3d59e5f --- /dev/null +++ b/src/anemoi/graphs/nodes/nodes.py @@ -0,0 +1,65 @@ +import logging +from abc import ABC +from abc import abstractmethod +from pathlib import Path + +import numpy as np +import torch +from anemoi.datasets import open_dataset +from anemoi.utils.config import DotDict +from hydra.utils import instantiate +from torch_geometric.data import HeteroData + +logger = logging.getLogger(__name__) + + +class BaseNodeBuilder(ABC): + """Base class for node builders.""" + + def register_nodes(self, graph: HeteroData, name: str) -> None: + graph[name].x = self.get_coordinates() + graph[name].node_type = type(self).__name__ + return graph + + def register_attributes(self, graph: HeteroData, name: str, config: DotDict) -> HeteroData: + for nodes_attr_name, attr_cfg in config.items(): + graph[name][nodes_attr_name] = instantiate(attr_cfg).get_weights(graph[name]) + return graph + + @abstractmethod + def get_coordinates(self) -> np.ndarray: ... + + def reshape_coords(self, latitudes: np.ndarray, longitudes: np.ndarray) -> np.ndarray: + coords = np.stack([latitudes, longitudes], axis=-1).reshape((-1, 2)) + coords = np.deg2rad(coords) + return torch.tensor(coords, dtype=torch.float32) + + def transform(self, graph: HeteroData, name: str, attr_config: DotDict) -> HeteroData: + graph = self.register_nodes(graph, name) + graph = self.register_attributes(graph, name, attr_config) + return graph + + + +class ZarrNodes(BaseNodeBuilder): + """Nodes from Zarr dataset.""" + + def __init__(self, dataset: DotDict) -> None: + logger.info("Reading the dataset from %s.", dataset) + self.ds = open_dataset(dataset) + + def get_coordinates(self) -> torch.Tensor: + return self.reshape_coords(self.ds.latitudes, self.ds.longitudes) + + +class NPZNodes(BaseNodeBuilder): + """Nodes from NPZ defined grids.""" + + def __init__(self, resolution: str, grid_definition_path: str) -> None: + self.resolution = resolution + self.grid_definition_path = grid_definition_path + self.grid_definition = np.load(Path(self.grid_definition_path) / f"grid-{self.resolution}.npz") + + def get_coordinates(self) -> np.ndarray: + coords = self.reshape_coords(self.grid_definition["latitudes"], self.grid_definition["longitudes"]) + return coords From 1fe76bbd3e60a6cdc3ce1d4a5da95f81fbb20a91 Mon Sep 17 00:00:00 2001 From: theissenhelen Date: Fri, 28 Jun 2024 15:46:35 +0000 Subject: [PATCH 026/156] feat: Icosahedral edge builder --- src/anemoi/graphs/edges/builder.py | 45 ++++++++++++++++++++++++++++++ 1 file changed, 45 insertions(+) diff --git a/src/anemoi/graphs/edges/builder.py b/src/anemoi/graphs/edges/builder.py index 759502d..006d143 100644 --- a/src/anemoi/graphs/edges/builder.py +++ b/src/anemoi/graphs/edges/builder.py @@ -4,6 +4,7 @@ import numpy as np import torch +import networkx as nx from anemoi.utils.config import DotDict from hydra.utils import instantiate from sklearn.neighbors import NearestNeighbors @@ -12,6 +13,8 @@ from anemoi.graphs import EARTH_RADIUS from anemoi.graphs.utils import get_grid_reference_distance +from anemoi.graphs.nodes.builder import TriRefinedIcosahedralNodeBuilder +from anemoi.graphs.generate import icosahedral logger = logging.getLogger(__name__) @@ -119,3 +122,45 @@ def get_adj_matrix(self, src_nodes: NodeStorage, dst_nodes: NodeStorage): nearest_neighbour.fit(src_nodes.x) adj_matrix = nearest_neighbour.radius_neighbors_graph(dst_nodes.x, radius=self.radius).tocoo() return adj_matrix + + +class TriIcosahedralEdgeBuilder(BaseEdgeBuilder): + """Computes icosahedral edges and adds them to a HeteroData graph.""" + + def __init__(self, src_name: str, dst_name: str, xhops: int): + super().__init__(src_name, dst_name) + + assert isinstance(xhops, int), "Number of xhops must be an integer" + assert xhops > 0, "Number of xhops must be positive" + + self.xhops = xhops + + def transform(self, graph: HeteroData, edge_name: str, attrs_config: Optional[DotDict] = None) -> HeteroData: + + assert ( + graph[self.src_name].node_type == TriRefinedIcosahedralNodeBuilder.__name__ + ), "IcosahedralConnection requires MultiScaleIcosahedral nodes." + assert graph[self.src_name] == graph[self.dst_name], "InheritConnection requires the same nodes for source and destination." + + # TODO: Next assert doesn't exist anymore since filters were moved, make sure this is checked where appropriate + # assert filter_src is None and filter_dst is None, "InheritConnection does not support filtering with attributes." + + return super().transform(graph, edge_name, attrs_config) + + def get_adj_matrix(self, src_nodes: NodeStorage, dst_nodes: NodeStorage): + + src_nodes["nx_graph"] = icosahedral.add_edges_to_nx_graph( + src_nodes["nx_graph"], + resolutions=src_nodes["resolutions"], + xhops=self.xhops, + aoi_nneighb=None if "aoi_nneighb" not in src_nodes else src_nodes["aoi_nneigh"], + ) # HeteroData refuses to accept None + + adjmat = nx.to_scipy_sparse_array(src_nodes["nx_graph"], nodelist=list(src_nodes["nx_graph"]), format="coo") + graph_1_sorted = dict(zip(range(len(src_nodes["nx_graph"].nodes)), list(src_nodes["nx_graph"].nodes))) + graph_2_sorted = dict(zip(src_nodes.node_ordering, range(len(src_nodes.node_ordering)))) + sort_func1 = np.vectorize(graph_1_sorted.get) + sort_func2 = np.vectorize(graph_2_sorted.get) + adjmat.row = sort_func2(sort_func1(adjmat.row)) + adjmat.col = sort_func2(sort_func1(adjmat.col)) + return adjmat From fde0fe65fdca5316789fdf4558fafbb566e25bcc Mon Sep 17 00:00:00 2001 From: theissenhelen Date: Fri, 28 Jun 2024 16:08:38 +0000 Subject: [PATCH 027/156] feat: hexagonal graph generation Co-authored-by: Mario Santa Cruz --- src/anemoi/graphs/generate/hexagonal.py | 260 ++++++++++++++++++++++++ 1 file changed, 260 insertions(+) create mode 100644 src/anemoi/graphs/generate/hexagonal.py diff --git a/src/anemoi/graphs/generate/hexagonal.py b/src/anemoi/graphs/generate/hexagonal.py new file mode 100644 index 0000000..902d88b --- /dev/null +++ b/src/anemoi/graphs/generate/hexagonal.py @@ -0,0 +1,260 @@ +from typing import Optional + +import h3 +import networkx as nx +import numpy as np +import torch +from sklearn.metrics.pairwise import haversine_distances + + +def add_edge( + graph: nx.Graph, + idx1: str, + idx2: str, + allow_self_loop: bool = False, +) -> None: + """Add edge between two nodes to a graph. + + The edge will only be added in case both tail and head nodes are included in the graph, G. + + Parameters + ---------- + graph : networkx.Graph + The graph to add the nodes. + idx1 : str + The H3 index of the tail of the edge. + idx2 : str + The H3 index of the head of the edge. + allow_self_loop : bool + Whether to allow self-loops or not. Defaults to not allowing self-loops. + """ + if not graph.has_node(idx1) or not graph.has_node(idx2): + return + + if allow_self_loop or idx1 != idx2: + loc1 = np.deg2rad(h3.h3_to_geo(idx1)) + loc2 = np.deg2rad(h3.h3_to_geo(idx2)) + graph.add_edge(idx1, idx2, weight=haversine_distances([loc1, loc2])[0][1]) + + +def get_cells_at_resolution( + resolution: int, + area: Optional[dict] = None, + aoi_mask_builder: Optional[KNNAreaMaskBuilder] = None, +) -> set[str]: + """Get cells at a specified refinement level. + + Parameters + ---------- + resolution : int + The H3 refinement level. It can be an integer from 0 to 15. + area : dict + A region, in GeoJSON data format, to be contained by all cells. Defaults to None. + aoi_mask_builder : KNNAreaMaskBuilder, optional + KNNAreaMaskBuilder computes nask to limit the mesh area, by default None. + + Returns + ------- + cells : set[str] + The set of H3 indexes at the specified resolution level. + """ + # TODO: What is area? + cells = h3.uncompact(h3.get_res0_indexes(), resolution) if area is None else h3.polyfill(area, resolution) + + if aoi_mask_builder is not None: + cells = list(cells) + + coords = np.deg2rad(np.array([h3.h3_to_geo(c) for c in cells])) + aoi_mask = aoi_mask_builder.get_mask(coords) + + cells = set(map(str, np.array(cells)[aoi_mask])) + + return cells + + +def add_nodes_for_resolution( + graph: nx.Graph, + resolution: int, + self_loop: bool = False, + **area_kwargs: Optional[dict], +) -> None: + """Add all nodes at a specified refinement level to a graph. + + Parameters + ---------- + graph : networkx.Graph + The graph to add the nodes. + resolution : int + The H3 refinement level. It can be an integer from 0 to 15. + self_loop : int + Whether to include self-loops in the nodes added or not. + area_kwargs: dict + Additional arguments to pass to the get_cells_at_resolution function. + """ + for idx in get_cells_at_resolution(resolution, **area_kwargs): + graph.add_node(idx, hcoords_rad=np.deg2rad(h3.h3_to_geo(idx))) + if self_loop: + # TODO: should that be add_self_loops(graph)? + add_edge(graph, idx, idx, allow_self_loop=self_loop) + + +def add_neighbour_edges( + graph: nx.Graph, + refinement_levels: tuple[int], + flat: bool = True, +) -> None: + for resolution in refinement_levels: + cells = {node for node in graph.nodes if h3.h3_get_resolution(node) == resolution} + for idx in cells: + k = 2 if resolution == 0 else 1 # refinement_levels[0]: # extra large field of vision ; only few nodes + + # neighbours + for idx_neighbour in h3.k_ring(idx, k=k) & cells: + if flat: + add_edge( + graph, + h3.h3_to_center_child(idx, refinement_levels[-1]), + h3.h3_to_center_child(idx_neighbour, refinement_levels[-1]), + ) + else: + add_edge(graph, idx, idx_neighbour) + + +def create_hexagonal_nodes( + resolutions: list[int], + flat: bool = True, + area: Optional[dict] = None, + aoi_mask_builder: Optional[KNNAreaMaskBuilder] = None, +) -> tuple[nx.Graph, torch.Tensor, list[int]]: + """Creates a global mesh from a refined icosahedro. + + This method relies on the H3 python library, which covers the earth with hexagons (and 5 pentagons). At each + refinement level, a hexagon cell has 7 child cells (aperture 7). + + Parameters + ---------- + resolutions : list[int] + Levels of mesh resolution to consider. + flat : bool + Whether or not all resolution levels of the mesh are included. + area : dict + A region, in GeoJSON data format, to be contained by all cells. Defaults to None, which computes the global + mesh. + aoi_mask_builder : KNNAreaMaskBuilder, optional + KNNAreaMaskBuilder with the cloud of points to limit the mesh area, by default None. + + Returns + ------- + graph : networkx.Graph + The specified graph (nodes & edges). + """ + graph = nx.Graph() + + area_kwargs = {"area": area, "aoi_mask_builder": aoi_mask_builder} + + for resolution in resolutions: + add_nodes_for_resolution(graph, resolution, **area_kwargs) + + coords = np.array([h3.h3_to_geo(node) for node in graph.nodes]) + + # Sort nodes by latitude and longitude + node_ordering = np.lexsort(coords.T[::-1], axis=0) + + # Should these be sorted here or in the edge builder? + coords = coords[node_ordering] + + return graph, coords, node_ordering + + +def add_edges_to_nx_graph( + graph: nx.Graph, + resolutions: list[int], + self_loop: bool = False, + flat: bool = True, + neighbour_children: bool = False, + depth_children: int = 1, +) -> nx.Graph: + """Creates a global mesh from a refined icosahedro. + + This method relies on the H3 python library, which covers the earth with hexagons (and 5 pentagons). At each + refinement level, a hexagon cell has 7 child cells (aperture 7). + + Parameters + ---------- + graph : networkx.Graph + The graph to add the nodes. + resolutions : list[int] + Levels of mesh resolution to consider. + self_loop : bool + Whether include a self-loop in every node or not. + flat : bool + Whether or not all resolution levels of the mesh are included. + neighbour_children : bool + Whether to include connections with the children from the neighbours. + depth_children : int + The number of resolution levels to consider for the connections of children. Defaults to 1, which includes + connections up to the next resolution level. + + Returns + ------- + graph : networkx.Graph + The specified graph (nodes & edges). + """ + if self_loop: + add_self_loops(graph) + + add_neighbour_edges(graph, resolutions, flat) + add_children_edges( + graph, + resolutions, + flat, + neighbour_children, + depth_children, + ) + return graph + + +def add_self_loops(graph: nx.Graph) -> None: + + for idx in graph.nodes: + add_edge(graph, idx, idx, allow_self_loop=True) + + +def add_children_edges( + graph: nx.Graph, + refinement_levels: tuple[int], + flat: bool = True, + neighbour_children: bool = False, + depth: Optional[int] = None, +) -> None: + if depth is None: + depth = len(refinement_levels) + + for ip, resolution_parent in enumerate(refinement_levels[0:-1]): + parent_cells = [node for node in graph.nodes if h3.h3_get_resolution(node) == resolution_parent] + for idx_parent in parent_cells: + # add own children + for resolution_child in refinement_levels[ip + 1 : ip + depth + 1]: + for idx_child in h3.h3_to_children(idx_parent, res=resolution_child): + if flat: + add_edge( + graph, + h3.h3_to_center_child(idx_parent, refinement_levels[-1]), + h3.h3_to_center_child(idx_child, refinement_levels[-1]), + ) + else: + add_edge(graph, idx_parent, idx_child) + + # add neighbour children + if neighbour_children: + for idx_parent_neighbour in h3.k_ring(idx_parent, k=1) & parent_cells: + for resolution_child in refinement_levels[ip + 1 : ip + depth + 1]: + for idx_child_neighbour in h3.h3_to_children(idx_parent_neighbour, res=resolution_child): + if flat: + add_edge( + graph, + h3.h3_to_center_child(idx_parent, refinement_levels[-1]), + h3.h3_to_center_child(idx_child_neighbour, refinement_levels[-1]), + ) + else: + add_edge(graph, idx_parent, idx_child_neighbour) From 97ef0da641cdd88b84f67ab471a96399c67d092a Mon Sep 17 00:00:00 2001 From: theissenhelen Date: Fri, 28 Jun 2024 16:10:01 +0000 Subject: [PATCH 028/156] feat: hexagonal builders --- src/anemoi/graphs/edges/builder.py | 37 ++++++++++++++++++++++++++++++ src/anemoi/graphs/nodes/builder.py | 7 ++++++ 2 files changed, 44 insertions(+) diff --git a/src/anemoi/graphs/edges/builder.py b/src/anemoi/graphs/edges/builder.py index 006d143..bd4f4cd 100644 --- a/src/anemoi/graphs/edges/builder.py +++ b/src/anemoi/graphs/edges/builder.py @@ -14,6 +14,7 @@ from anemoi.graphs import EARTH_RADIUS from anemoi.graphs.utils import get_grid_reference_distance from anemoi.graphs.nodes.builder import TriRefinedIcosahedralNodeBuilder +from anemoi.graphs.nodes.builder import HexRefinedIcosahedralNodeBuilder from anemoi.graphs.generate import icosahedral logger = logging.getLogger(__name__) @@ -164,3 +165,39 @@ def get_adj_matrix(self, src_nodes: NodeStorage, dst_nodes: NodeStorage): adjmat.row = sort_func2(sort_func1(adjmat.row)) adjmat.col = sort_func2(sort_func1(adjmat.col)) return adjmat + + +class HexagonalEdgeBuilder(BaseEdgeBuilder): + """Computes hexagonal edges and adds them to a HeteroData graph.""" + + def __init__(self, src_name: str, dst_name: str, add_neighbouring_children: bool = False, depth_children: Optional[int] = 1): + super().__init__(src_name, dst_name) + self.add_neighbouring_children = add_neighbouring_children + self.depth_children = depth_children + + def transform(self, graph: HeteroData, edge_name: str, attrs_config: Optional[DotDict] = None) -> HeteroData: + assert ( + graph[self.src_name].node_type == HexRefinedIcosahedralNodeBuilder.__name__ + ), "IcosahedralConnection requires MultiScaleIcosahedral nodes." + assert graph[self.src_name] == graph[self.dst_name], "InheritConnection requires the same nodes for source and destination." + + # TODO: Next assert doesn't exist anymore since filters were moved, make sure this is checked where appropriate + # assert filter_src is None and filter_dst is None, "InheritConnection does not support filtering with attributes." + + return super().transform(graph, edge_name, attrs_config) + + def get_adj_matrix(self, src_nodes: NodeStorage, dst_nodes: NodeStorage): + + src_nodes["nx_graph"] = hexagonal.add_edges_to_nx_graph( + src_nodes["nx_graph"], + resolutions=src_nodes["resolutions"], + neighbour_children=self.add_neighbouring_children, + depth_children=self.depth_children, + ) + + adjmat = nx.to_scipy_sparse_array(src_nodes["nx_graph"], format="coo") + graph_2_sorted = dict(zip(src_nodes["node_ordering"], range(len(src_nodes.node_ordering)))) + sort_func = np.vectorize(graph_2_sorted.get) + adjmat.row = sort_func(adjmat.row) + adjmat.col = sort_func(adjmat.col) + return adjmat \ No newline at end of file diff --git a/src/anemoi/graphs/nodes/builder.py b/src/anemoi/graphs/nodes/builder.py index b003fa0..8d74c09 100644 --- a/src/anemoi/graphs/nodes/builder.py +++ b/src/anemoi/graphs/nodes/builder.py @@ -11,6 +11,7 @@ from hydra.utils import instantiate from torch_geometric.data import HeteroData from anemoi.graphs.generate.icosahedral import create_icosahedral_nodes +from anemoi.graphs.generate.hexagonal import create_hexagonal_nodes logger = logging.getLogger(__name__) @@ -114,3 +115,9 @@ def create_nodes(self) -> np.ndarray: # TODO: AOI mask builder is not used in the current implementation. return create_icosahedral_nodes(resolutions=self.resolutions) + +class HexRefinedIcosahedralNodeBuilder(RefinedIcosahedralNodeBuilder): + """It depends on the h3 Python library.""" + + def create_nodes(self) -> np.ndarray: + return create_hexagonal_nodes(self.resolutions) From 8dcda40844e030348816e9471a5fc81c7b101e64 Mon Sep 17 00:00:00 2001 From: theissenhelen Date: Fri, 28 Jun 2024 16:13:05 +0000 Subject: [PATCH 029/156] fix: AOI not implemented yet --- src/anemoi/graphs/generate/hexagonal.py | 12 ++---------- 1 file changed, 2 insertions(+), 10 deletions(-) diff --git a/src/anemoi/graphs/generate/hexagonal.py b/src/anemoi/graphs/generate/hexagonal.py index 902d88b..66e18f6 100644 --- a/src/anemoi/graphs/generate/hexagonal.py +++ b/src/anemoi/graphs/generate/hexagonal.py @@ -40,7 +40,6 @@ def add_edge( def get_cells_at_resolution( resolution: int, area: Optional[dict] = None, - aoi_mask_builder: Optional[KNNAreaMaskBuilder] = None, ) -> set[str]: """Get cells at a specified refinement level. @@ -61,13 +60,7 @@ def get_cells_at_resolution( # TODO: What is area? cells = h3.uncompact(h3.get_res0_indexes(), resolution) if area is None else h3.polyfill(area, resolution) - if aoi_mask_builder is not None: - cells = list(cells) - - coords = np.deg2rad(np.array([h3.h3_to_geo(c) for c in cells])) - aoi_mask = aoi_mask_builder.get_mask(coords) - - cells = set(map(str, np.array(cells)[aoi_mask])) + # TODO: AOI not used in the current implementation. return cells @@ -124,7 +117,6 @@ def create_hexagonal_nodes( resolutions: list[int], flat: bool = True, area: Optional[dict] = None, - aoi_mask_builder: Optional[KNNAreaMaskBuilder] = None, ) -> tuple[nx.Graph, torch.Tensor, list[int]]: """Creates a global mesh from a refined icosahedro. @@ -150,7 +142,7 @@ def create_hexagonal_nodes( """ graph = nx.Graph() - area_kwargs = {"area": area, "aoi_mask_builder": aoi_mask_builder} + area_kwargs = {"area": area} for resolution in resolutions: add_nodes_for_resolution(graph, resolution, **area_kwargs) From 63be6af37b91cca06ebe793c9603848eb5639eea Mon Sep 17 00:00:00 2001 From: theissenhelen Date: Mon, 1 Jul 2024 11:06:54 +0000 Subject: [PATCH 030/156] fix: abstractmethod and renaming --- src/anemoi/graphs/edges/builder.py | 27 +++++++++++++++++---------- src/anemoi/graphs/nodes/builder.py | 12 +++++++----- 2 files changed, 24 insertions(+), 15 deletions(-) diff --git a/src/anemoi/graphs/edges/builder.py b/src/anemoi/graphs/edges/builder.py index bd4f4cd..28b40c1 100644 --- a/src/anemoi/graphs/edges/builder.py +++ b/src/anemoi/graphs/edges/builder.py @@ -2,9 +2,9 @@ from abc import abstractmethod from typing import Optional +import networkx as nx import numpy as np import torch -import networkx as nx from anemoi.utils.config import DotDict from hydra.utils import instantiate from sklearn.neighbors import NearestNeighbors @@ -12,10 +12,11 @@ from torch_geometric.data.storage import NodeStorage from anemoi.graphs import EARTH_RADIUS -from anemoi.graphs.utils import get_grid_reference_distance -from anemoi.graphs.nodes.builder import TriRefinedIcosahedralNodeBuilder -from anemoi.graphs.nodes.builder import HexRefinedIcosahedralNodeBuilder +from anemoi.graphs.generate import hexagonal from anemoi.graphs.generate import icosahedral +from anemoi.graphs.nodes.builder import HexRefinedIcosahedralNodeBuilder +from anemoi.graphs.nodes.builder import TriRefinedIcosahedralNodeBuilder +from anemoi.graphs.utils import get_grid_reference_distance logger = logging.getLogger(__name__) @@ -125,7 +126,7 @@ def get_adj_matrix(self, src_nodes: NodeStorage, dst_nodes: NodeStorage): return adj_matrix -class TriIcosahedralEdgeBuilder(BaseEdgeBuilder): +class TriIcosahedralEdges(BaseEdgeBuilder): """Computes icosahedral edges and adds them to a HeteroData graph.""" def __init__(self, src_name: str, dst_name: str, xhops: int): @@ -141,7 +142,9 @@ def transform(self, graph: HeteroData, edge_name: str, attrs_config: Optional[Do assert ( graph[self.src_name].node_type == TriRefinedIcosahedralNodeBuilder.__name__ ), "IcosahedralConnection requires MultiScaleIcosahedral nodes." - assert graph[self.src_name] == graph[self.dst_name], "InheritConnection requires the same nodes for source and destination." + assert ( + graph[self.src_name] == graph[self.dst_name] + ), "InheritConnection requires the same nodes for source and destination." # TODO: Next assert doesn't exist anymore since filters were moved, make sure this is checked where appropriate # assert filter_src is None and filter_dst is None, "InheritConnection does not support filtering with attributes." @@ -167,10 +170,12 @@ def get_adj_matrix(self, src_nodes: NodeStorage, dst_nodes: NodeStorage): return adjmat -class HexagonalEdgeBuilder(BaseEdgeBuilder): +class HexagonalEdges(BaseEdgeBuilder): """Computes hexagonal edges and adds them to a HeteroData graph.""" - def __init__(self, src_name: str, dst_name: str, add_neighbouring_children: bool = False, depth_children: Optional[int] = 1): + def __init__( + self, src_name: str, dst_name: str, add_neighbouring_children: bool = False, depth_children: Optional[int] = 1 + ): super().__init__(src_name, dst_name) self.add_neighbouring_children = add_neighbouring_children self.depth_children = depth_children @@ -179,7 +184,9 @@ def transform(self, graph: HeteroData, edge_name: str, attrs_config: Optional[Do assert ( graph[self.src_name].node_type == HexRefinedIcosahedralNodeBuilder.__name__ ), "IcosahedralConnection requires MultiScaleIcosahedral nodes." - assert graph[self.src_name] == graph[self.dst_name], "InheritConnection requires the same nodes for source and destination." + assert ( + graph[self.src_name] == graph[self.dst_name] + ), "InheritConnection requires the same nodes for source and destination." # TODO: Next assert doesn't exist anymore since filters were moved, make sure this is checked where appropriate # assert filter_src is None and filter_dst is None, "InheritConnection does not support filtering with attributes." @@ -200,4 +207,4 @@ def get_adj_matrix(self, src_nodes: NodeStorage, dst_nodes: NodeStorage): sort_func = np.vectorize(graph_2_sorted.get) adjmat.row = sort_func(adjmat.row) adjmat.col = sort_func(adjmat.col) - return adjmat \ No newline at end of file + return adjmat diff --git a/src/anemoi/graphs/nodes/builder.py b/src/anemoi/graphs/nodes/builder.py index 8d74c09..6ea6bf0 100644 --- a/src/anemoi/graphs/nodes/builder.py +++ b/src/anemoi/graphs/nodes/builder.py @@ -10,8 +10,9 @@ from anemoi.utils.config import DotDict from hydra.utils import instantiate from torch_geometric.data import HeteroData -from anemoi.graphs.generate.icosahedral import create_icosahedral_nodes + from anemoi.graphs.generate.hexagonal import create_hexagonal_nodes +from anemoi.graphs.generate.icosahedral import create_icosahedral_nodes logger = logging.getLogger(__name__) @@ -67,7 +68,7 @@ def get_coordinates(self) -> np.ndarray: return coords -class RefinedIcosahedralNodeBuilder(BaseNodeBuilder): +class RefinedIcosahedralNodes(BaseNodeBuilder, ABC): """Processor mesh based on a triangular mesh. It is based on the icosahedral mesh, which is a mesh of triangles that covers the sphere. @@ -98,6 +99,7 @@ def get_coordinates(self) -> np.ndarray: self.nx_graph, coords_rad, self.node_ordering = self.create_nodes() return coords_rad[self.node_ordering] + @abstractmethod def create_nodes(self) -> np.ndarray: ... def register_attributes(self, graph: HeteroData, name: str, config: DotDict) -> HeteroData: @@ -108,15 +110,15 @@ def register_attributes(self, graph: HeteroData, name: str, config: DotDict) -> return super().register_attributes(graph, name, config) -class TriRefinedIcosahedralNodeBuilder(RefinedIcosahedralNodeBuilder): +class TriRefinedIcosahedralNodes(RefinedIcosahedralNodes): """It depends on the trimesh Python library.""" def create_nodes(self) -> np.ndarray: # TODO: AOI mask builder is not used in the current implementation. - return create_icosahedral_nodes(resolutions=self.resolutions) + return create_icosahedral_nodes(resolutions=self.resolutions) -class HexRefinedIcosahedralNodeBuilder(RefinedIcosahedralNodeBuilder): +class HexRefinedIcosahedralNodes(RefinedIcosahedralNodes): """It depends on the h3 Python library.""" def create_nodes(self) -> np.ndarray: From b175585f22b7c83e57c70069d42b0e32e5a38d79 Mon Sep 17 00:00:00 2001 From: theissenhelen Date: Mon, 1 Jul 2024 12:55:02 +0000 Subject: [PATCH 031/156] chore: add dependencies --- pyproject.toml | 3 +++ 1 file changed, 3 insertions(+) diff --git a/pyproject.toml b/pyproject.toml index cb5bb7f..6654b59 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -52,9 +52,12 @@ dynamic = [ dependencies = [ "anemoi-datasets[data]>=0.3.3", "anemoi-utils>=0.3.6", + "h3>=3.7.6", "hydra-core>=1.3", + "networkx>=3.1", "torch>=2.2", "torch-geometric>=2.3.1,<2.5", + "trimesh>=4.1", ] optional-dependencies.all = [ From 86c5e354290abd82bb243faa48a948ba159da4a9 Mon Sep 17 00:00:00 2001 From: theissenhelen Date: Mon, 1 Jul 2024 13:49:07 +0000 Subject: [PATCH 032/156] test: add tests for trimesh --- tests/nodes/test_tri_refined_icosahedral.py | 33 +++++++++++++++++++++ 1 file changed, 33 insertions(+) create mode 100644 tests/nodes/test_tri_refined_icosahedral.py diff --git a/tests/nodes/test_tri_refined_icosahedral.py b/tests/nodes/test_tri_refined_icosahedral.py new file mode 100644 index 0000000..762efdf --- /dev/null +++ b/tests/nodes/test_tri_refined_icosahedral.py @@ -0,0 +1,33 @@ +import pytest +import torch +from torch_geometric.data import HeteroData + +from anemoi.graphs.nodes import builder + + +@pytest.mark.parametrize("resolution", [0, 2]) +def test_init(resolution: list[int]): + """Test TrirefinedIcosahedralNodes initialization.""" + + node_builder = builder.TriRefinedIcosahedralNodes(resolution) + assert isinstance(node_builder, builder.BaseNodeBuilder) + assert isinstance(node_builder, builder.TriRefinedIcosahedralNodes) + + +def test_get_coordinates(): + """Test get_coordinates method.""" + node_builder = builder.TriRefinedIcosahedralNodes(2) + coords = node_builder.get_coordinates() + assert isinstance(coords, torch.Tensor) + assert coords.shape == (162, 2) + + +def test_transform(): + """Test transform method.""" + node_builder = builder.TriRefinedIcosahedralNodes(1) + graph = HeteroData() + graph = node_builder.transform(graph, "test", {}) + assert "resolutions" in graph["test"] + assert "nx_graph" in graph["test"] + assert "node_ordering" in graph["test"] + assert len(graph["test"]["node_ordering"]) == graph["test"].num_nodes From 19461a1bb7496ff508590f22eb3c670173de99c7 Mon Sep 17 00:00:00 2001 From: theissenhelen Date: Mon, 1 Jul 2024 14:02:05 +0000 Subject: [PATCH 033/156] test: add tests for hex (h3) --- tests/nodes/test_hex_refined_icosahedral.py | 33 +++++++++++++++++++++ 1 file changed, 33 insertions(+) create mode 100644 tests/nodes/test_hex_refined_icosahedral.py diff --git a/tests/nodes/test_hex_refined_icosahedral.py b/tests/nodes/test_hex_refined_icosahedral.py new file mode 100644 index 0000000..df0e716 --- /dev/null +++ b/tests/nodes/test_hex_refined_icosahedral.py @@ -0,0 +1,33 @@ +import pytest +import torch +from torch_geometric.data import HeteroData + +from anemoi.graphs.nodes import builder + + +@pytest.mark.parametrize("resolution", [0, 2]) +def test_init(resolution: list[int]): + """Test TrirefinedIcosahedralNodes initialization.""" + + node_builder = builder.HexRefinedIcosahedralNodes(resolution) + assert isinstance(node_builder, builder.BaseNodeBuilder) + assert isinstance(node_builder, builder.HexRefinedIcosahedralNodes) + + +def test_get_coordinates(): + """Test get_coordinates method.""" + node_builder = builder.HexRefinedIcosahedralNodes(0) + coords = node_builder.get_coordinates() + assert isinstance(coords, torch.Tensor) + assert coords.shape == (122, 2) + + +def test_transform(): + """Test transform method.""" + node_builder = builder.HexRefinedIcosahedralNodes(0) + graph = HeteroData() + graph = node_builder.transform(graph, "test", {}) + assert "resolutions" in graph["test"] + assert "nx_graph" in graph["test"] + assert "node_ordering" in graph["test"] + assert len(graph["test"]["node_ordering"]) == graph["test"].num_nodes From 39ee3adb27001ec676767144fb310f52c77487bc Mon Sep 17 00:00:00 2001 From: theissenhelen Date: Mon, 1 Jul 2024 14:03:03 +0000 Subject: [PATCH 034/156] fix: imports --- src/anemoi/graphs/nodes/__init__.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/anemoi/graphs/nodes/__init__.py b/src/anemoi/graphs/nodes/__init__.py index 737f27f..ef13e41 100644 --- a/src/anemoi/graphs/nodes/__init__.py +++ b/src/anemoi/graphs/nodes/__init__.py @@ -1,4 +1,6 @@ +from .builder import HexRefinedIcosahedralNodes from .builder import NPZFileNodes +from .builder import TriRefinedIcosahedralNodes from .builder import ZarrDatasetNodes -__all__ = ["ZarrDatasetNodes", "NPZFileNodes"] +__all__ = ["ZarrDatasetNodes", "NPZFileNodes", "TriRefinedIcosahedralNodes", "HexRefinedIcosahedralNodes"] From f00fd72c673af676a8a67c2b2a08710dbf0101a6 Mon Sep 17 00:00:00 2001 From: theissenhelen Date: Mon, 1 Jul 2024 14:03:44 +0000 Subject: [PATCH 035/156] fix: output type --- src/anemoi/graphs/nodes/builder.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/src/anemoi/graphs/nodes/builder.py b/src/anemoi/graphs/nodes/builder.py index 6ea6bf0..a477020 100644 --- a/src/anemoi/graphs/nodes/builder.py +++ b/src/anemoi/graphs/nodes/builder.py @@ -86,6 +86,7 @@ def __init__( resolution: Union[int, list[int]], np_dtype: np.dtype = np.float32, ) -> None: + # TODO: Discuss np_dtype self.np_dtype = np_dtype if isinstance(resolution, int): @@ -95,9 +96,9 @@ def __init__( super().__init__() - def get_coordinates(self) -> np.ndarray: + def get_coordinates(self) -> torch.Tensor: self.nx_graph, coords_rad, self.node_ordering = self.create_nodes() - return coords_rad[self.node_ordering] + return torch.tensor(coords_rad[self.node_ordering]) @abstractmethod def create_nodes(self) -> np.ndarray: ... @@ -122,4 +123,5 @@ class HexRefinedIcosahedralNodes(RefinedIcosahedralNodes): """It depends on the h3 Python library.""" def create_nodes(self) -> np.ndarray: + # TODO: AOI mask builder is not used in the current implementation. return create_hexagonal_nodes(self.resolutions) From 75a82c830e2d28d325cf1edfb000e2ea19888e2d Mon Sep 17 00:00:00 2001 From: theissenhelen Date: Mon, 1 Jul 2024 14:06:19 +0000 Subject: [PATCH 036/156] refactor: delete unused file --- src/anemoi/graphs/nodes/nodes.py | 65 -------------------------------- 1 file changed, 65 deletions(-) delete mode 100644 src/anemoi/graphs/nodes/nodes.py diff --git a/src/anemoi/graphs/nodes/nodes.py b/src/anemoi/graphs/nodes/nodes.py deleted file mode 100644 index 3d59e5f..0000000 --- a/src/anemoi/graphs/nodes/nodes.py +++ /dev/null @@ -1,65 +0,0 @@ -import logging -from abc import ABC -from abc import abstractmethod -from pathlib import Path - -import numpy as np -import torch -from anemoi.datasets import open_dataset -from anemoi.utils.config import DotDict -from hydra.utils import instantiate -from torch_geometric.data import HeteroData - -logger = logging.getLogger(__name__) - - -class BaseNodeBuilder(ABC): - """Base class for node builders.""" - - def register_nodes(self, graph: HeteroData, name: str) -> None: - graph[name].x = self.get_coordinates() - graph[name].node_type = type(self).__name__ - return graph - - def register_attributes(self, graph: HeteroData, name: str, config: DotDict) -> HeteroData: - for nodes_attr_name, attr_cfg in config.items(): - graph[name][nodes_attr_name] = instantiate(attr_cfg).get_weights(graph[name]) - return graph - - @abstractmethod - def get_coordinates(self) -> np.ndarray: ... - - def reshape_coords(self, latitudes: np.ndarray, longitudes: np.ndarray) -> np.ndarray: - coords = np.stack([latitudes, longitudes], axis=-1).reshape((-1, 2)) - coords = np.deg2rad(coords) - return torch.tensor(coords, dtype=torch.float32) - - def transform(self, graph: HeteroData, name: str, attr_config: DotDict) -> HeteroData: - graph = self.register_nodes(graph, name) - graph = self.register_attributes(graph, name, attr_config) - return graph - - - -class ZarrNodes(BaseNodeBuilder): - """Nodes from Zarr dataset.""" - - def __init__(self, dataset: DotDict) -> None: - logger.info("Reading the dataset from %s.", dataset) - self.ds = open_dataset(dataset) - - def get_coordinates(self) -> torch.Tensor: - return self.reshape_coords(self.ds.latitudes, self.ds.longitudes) - - -class NPZNodes(BaseNodeBuilder): - """Nodes from NPZ defined grids.""" - - def __init__(self, resolution: str, grid_definition_path: str) -> None: - self.resolution = resolution - self.grid_definition_path = grid_definition_path - self.grid_definition = np.load(Path(self.grid_definition_path) / f"grid-{self.resolution}.npz") - - def get_coordinates(self) -> np.ndarray: - coords = self.reshape_coords(self.grid_definition["latitudes"], self.grid_definition["longitudes"]) - return coords From f45b900f3afefff9e8f00e445bd0528fd5cf7290 Mon Sep 17 00:00:00 2001 From: theissenhelen Date: Mon, 1 Jul 2024 15:50:54 +0000 Subject: [PATCH 037/156] refactor: renaming and positioning --- src/anemoi/graphs/generate/hexagonal.py | 199 +++++++++++----------- src/anemoi/graphs/generate/icosahedral.py | 41 ++--- 2 files changed, 119 insertions(+), 121 deletions(-) diff --git a/src/anemoi/graphs/generate/hexagonal.py b/src/anemoi/graphs/generate/hexagonal.py index 66e18f6..4ab9569 100644 --- a/src/anemoi/graphs/generate/hexagonal.py +++ b/src/anemoi/graphs/generate/hexagonal.py @@ -7,62 +7,49 @@ from sklearn.metrics.pairwise import haversine_distances -def add_edge( - graph: nx.Graph, - idx1: str, - idx2: str, - allow_self_loop: bool = False, -) -> None: - """Add edge between two nodes to a graph. - - The edge will only be added in case both tail and head nodes are included in the graph, G. - - Parameters - ---------- - graph : networkx.Graph - The graph to add the nodes. - idx1 : str - The H3 index of the tail of the edge. - idx2 : str - The H3 index of the head of the edge. - allow_self_loop : bool - Whether to allow self-loops or not. Defaults to not allowing self-loops. - """ - if not graph.has_node(idx1) or not graph.has_node(idx2): - return - - if allow_self_loop or idx1 != idx2: - loc1 = np.deg2rad(h3.h3_to_geo(idx1)) - loc2 = np.deg2rad(h3.h3_to_geo(idx2)) - graph.add_edge(idx1, idx2, weight=haversine_distances([loc1, loc2])[0][1]) - - -def get_cells_at_resolution( - resolution: int, +def create_hexagonal_nodes( + resolutions: list[int], + flat: bool = True, area: Optional[dict] = None, -) -> set[str]: - """Get cells at a specified refinement level. +) -> tuple[nx.Graph, torch.Tensor, list[int]]: + """Creates a global mesh from a refined icosahedro. + + This method relies on the H3 python library, which covers the earth with hexagons (and 5 pentagons). At each + refinement level, a hexagon cell has 7 child cells (aperture 7). Parameters ---------- - resolution : int - The H3 refinement level. It can be an integer from 0 to 15. + resolutions : list[int] + Levels of mesh resolution to consider. + flat : bool + Whether or not all resolution levels of the mesh are included. area : dict - A region, in GeoJSON data format, to be contained by all cells. Defaults to None. + A region, in GeoJSON data format, to be contained by all cells. Defaults to None, which computes the global + mesh. aoi_mask_builder : KNNAreaMaskBuilder, optional - KNNAreaMaskBuilder computes nask to limit the mesh area, by default None. + KNNAreaMaskBuilder with the cloud of points to limit the mesh area, by default None. Returns ------- - cells : set[str] - The set of H3 indexes at the specified resolution level. + graph : networkx.Graph + The specified graph (nodes & edges). """ - # TODO: What is area? - cells = h3.uncompact(h3.get_res0_indexes(), resolution) if area is None else h3.polyfill(area, resolution) + graph = nx.Graph() - # TODO: AOI not used in the current implementation. + area_kwargs = {"area": area} - return cells + for resolution in resolutions: + add_nodes_for_resolution(graph, resolution, **area_kwargs) + + coords = np.array([h3.h3_to_geo(node) for node in graph.nodes]) + + # Sort nodes by latitude and longitude + node_ordering = np.lexsort(coords.T[::-1], axis=0) + + # Should these be sorted here or in the edge builder? + coords = coords[node_ordering] + + return graph, coords, node_ordering def add_nodes_for_resolution( @@ -84,78 +71,42 @@ def add_nodes_for_resolution( area_kwargs: dict Additional arguments to pass to the get_cells_at_resolution function. """ - for idx in get_cells_at_resolution(resolution, **area_kwargs): + + cells = get_cells_at_resolution(resolution, **area_kwargs) + + for idx in cells: graph.add_node(idx, hcoords_rad=np.deg2rad(h3.h3_to_geo(idx))) if self_loop: # TODO: should that be add_self_loops(graph)? add_edge(graph, idx, idx, allow_self_loop=self_loop) -def add_neighbour_edges( - graph: nx.Graph, - refinement_levels: tuple[int], - flat: bool = True, -) -> None: - for resolution in refinement_levels: - cells = {node for node in graph.nodes if h3.h3_get_resolution(node) == resolution} - for idx in cells: - k = 2 if resolution == 0 else 1 # refinement_levels[0]: # extra large field of vision ; only few nodes - - # neighbours - for idx_neighbour in h3.k_ring(idx, k=k) & cells: - if flat: - add_edge( - graph, - h3.h3_to_center_child(idx, refinement_levels[-1]), - h3.h3_to_center_child(idx_neighbour, refinement_levels[-1]), - ) - else: - add_edge(graph, idx, idx_neighbour) - - -def create_hexagonal_nodes( - resolutions: list[int], - flat: bool = True, +def get_cells_at_resolution( + resolution: int, area: Optional[dict] = None, -) -> tuple[nx.Graph, torch.Tensor, list[int]]: - """Creates a global mesh from a refined icosahedro. - - This method relies on the H3 python library, which covers the earth with hexagons (and 5 pentagons). At each - refinement level, a hexagon cell has 7 child cells (aperture 7). +) -> set[str]: + """Get cells at a specified refinement level. Parameters ---------- - resolutions : list[int] - Levels of mesh resolution to consider. - flat : bool - Whether or not all resolution levels of the mesh are included. + resolution : int + The H3 refinement level. It can be an integer from 0 to 15. area : dict - A region, in GeoJSON data format, to be contained by all cells. Defaults to None, which computes the global - mesh. + A region, in GeoJSON data format, to be contained by all cells. Defaults to None. aoi_mask_builder : KNNAreaMaskBuilder, optional - KNNAreaMaskBuilder with the cloud of points to limit the mesh area, by default None. + KNNAreaMaskBuilder computes nask to limit the mesh area, by default None. Returns ------- - graph : networkx.Graph - The specified graph (nodes & edges). + cells : set[str] + The set of H3 indexes at the specified resolution level. """ - graph = nx.Graph() - - area_kwargs = {"area": area} - - for resolution in resolutions: - add_nodes_for_resolution(graph, resolution, **area_kwargs) - - coords = np.array([h3.h3_to_geo(node) for node in graph.nodes]) - - # Sort nodes by latitude and longitude - node_ordering = np.lexsort(coords.T[::-1], axis=0) + # TODO: What is area? + cells = h3.uncompact(h3.get_res0_indexes(), resolution) if area is None else h3.polyfill(area, resolution) - # Should these be sorted here or in the edge builder? - coords = coords[node_ordering] + # TODO: AOI not used in the current implementation. - return graph, coords, node_ordering + return cells def add_edges_to_nx_graph( @@ -212,6 +163,28 @@ def add_self_loops(graph: nx.Graph) -> None: add_edge(graph, idx, idx, allow_self_loop=True) +def add_neighbour_edges( + graph: nx.Graph, + refinement_levels: tuple[int], + flat: bool = True, +) -> None: + for resolution in refinement_levels: + cells = {node for node in graph.nodes if h3.h3_get_resolution(node) == resolution} + for idx in cells: + k = 2 if resolution == 0 else 1 # refinement_levels[0]: # extra large field of vision ; only few nodes + + # neighbours + for idx_neighbour in h3.k_ring(idx, k=k) & cells: + if flat: + add_edge( + graph, + h3.h3_to_center_child(idx, refinement_levels[-1]), + h3.h3_to_center_child(idx_neighbour, refinement_levels[-1]), + ) + else: + add_edge(graph, idx, idx_neighbour) + + def add_children_edges( graph: nx.Graph, refinement_levels: tuple[int], @@ -250,3 +223,33 @@ def add_children_edges( ) else: add_edge(graph, idx_parent, idx_child_neighbour) + + +def add_edge( + graph: nx.Graph, + idx1: str, + idx2: str, + allow_self_loop: bool = False, +) -> None: + """Add edge between two nodes to a graph. + + The edge will only be added in case both tail and head nodes are included in the graph, G. + + Parameters + ---------- + graph : networkx.Graph + The graph to add the nodes. + idx1 : str + The H3 index of the tail of the edge. + idx2 : str + The H3 index of the head of the edge. + allow_self_loop : bool + Whether to allow self-loops or not. Defaults to not allowing self-loops. + """ + if not graph.has_node(idx1) or not graph.has_node(idx2): + return + + if allow_self_loop or idx1 != idx2: + loc1 = np.deg2rad(h3.h3_to_geo(idx1)) + loc2 = np.deg2rad(h3.h3_to_geo(idx2)) + graph.add_edge(idx1, idx2, weight=haversine_distances([loc1, loc2])[0][1]) diff --git a/src/anemoi/graphs/generate/icosahedral.py b/src/anemoi/graphs/generate/icosahedral.py index 97c4801..7f124cc 100644 --- a/src/anemoi/graphs/generate/icosahedral.py +++ b/src/anemoi/graphs/generate/icosahedral.py @@ -1,3 +1,4 @@ +import logging from collections.abc import Iterable from typing import Optional @@ -8,7 +9,6 @@ from sklearn.neighbors import BallTree from anemoi.graphs.generate.transforms import cartesian_to_latlon_rad -import logging logger = logging.getLogger(__name__) @@ -37,7 +37,7 @@ def create_icosahedral_nodes( Order of the nodes in the graph to be sorted by latitude and longitude. """ sphere = trimesh.creation.icosphere(subdivisions=resolutions[-1], radius=1.0) - + coords_rad = cartesian_to_latlon_rad(sphere.vertices) node_ordering = get_node_ordering(coords_rad) @@ -61,11 +61,11 @@ def create_icosahedral_nx_graph_from_coords(coords_rad: np.ndarray, node_orderin return graph -def get_node_ordering(vertices_rad: np.ndarray) -> np.ndarray: +def get_node_ordering(coords_rad: np.ndarray) -> np.ndarray: # Get indices to sort points by lon & lat in radians. - ind1 = np.argsort(vertices_rad[:, 1]) - ind2 = np.argsort(vertices_rad[ind1][:, 0])[::-1] - node_ordering = np.arange(vertices_rad.shape[0])[ind1][ind2] + ind1 = np.argsort(coords_rad[:, 1]) + ind2 = np.argsort(coords_rad[ind1][:, 0])[::-1] + node_ordering = np.arange(coords_rad.shape[0])[ind1][ind2] return node_ordering @@ -108,10 +108,10 @@ def add_edges_to_nx_graph( # TODO AOI mask builder is not used in the current implementation. valid_nodes = None - x_rings = get_x_hops(r_sphere, xhops, valid_nodes=valid_nodes) + x_hops = get_x_hops(r_sphere, xhops, valid_nodes=valid_nodes) _, idx = tree.query(r_vertices_rad, k=1) - for i, i_neighbours in x_rings.items(): + for i, i_neighbours in x_hops.items(): add_neigbours_edges(graph, r_vertices_rad, i, i_neighbours, idx=idx) return graph @@ -151,8 +151,8 @@ def get_x_hops(tri_mesh: trimesh.Trimesh, hops: int, valid_nodes: Optional[list[ def add_neigbours_edges( graph: nx.Graph, vertices: np.ndarray, - ii: int, - neighbours: Iterable[int], + node_idx: int, + neighbour_indices: Iterable[int], self_loops: bool = False, idx: Optional[np.ndarray] = None, ) -> None: @@ -164,7 +164,7 @@ def add_neigbours_edges( The graph. vertices : np.ndarray A 2D array of shape (num_vertices, 2) with the planar coordinates of the mesh, in radians. - ii : int + node_idx : int The node considered. neighbours : list[int] The neighbours of the node. @@ -173,26 +173,21 @@ def add_neigbours_edges( idx : np.ndarray, optional Index to map the vertices from the refined sphere to the original one, by default None. """ - for idx_neighbour in neighbours: - if not self_loops and ii == idx_neighbour: # no self-loops + for neighbour_idx in neighbour_indices: + if not self_loops and node_idx == neighbour_idx: # no self-loops continue - location_node = vertices[ii] - location_neighbour = vertices[idx_neighbour] + location_node = vertices[node_idx] + location_neighbour = vertices[neighbour_idx] edge_length = haversine_distances([location_neighbour, location_node])[0][1] if idx is not None: # Use the same method to add edge in all spheres - node_neighbour = idx[idx_neighbour][0] - node = idx[ii][0] + node_neighbour = idx[neighbour_idx][0] + node = idx[node_idx][0] else: - node, node_neighbour = ii, idx_neighbour + node, node_neighbour = node_idx, neighbour_idx # add edge to the graph if node in graph and node_neighbour in graph: graph.add_edge(node_neighbour, node, weight=edge_length) - - - - - From 9f2c0521f9c32bf51a076794723a9697c734d996 Mon Sep 17 00:00:00 2001 From: theissenhelen Date: Mon, 1 Jul 2024 16:49:46 +0000 Subject: [PATCH 038/156] feat: ensure src and dst always the same --- src/anemoi/graphs/edges/builder.py | 28 ++++++++++------------------ 1 file changed, 10 insertions(+), 18 deletions(-) diff --git a/src/anemoi/graphs/edges/builder.py b/src/anemoi/graphs/edges/builder.py index 28b40c1..e9be430 100644 --- a/src/anemoi/graphs/edges/builder.py +++ b/src/anemoi/graphs/edges/builder.py @@ -14,8 +14,8 @@ from anemoi.graphs import EARTH_RADIUS from anemoi.graphs.generate import hexagonal from anemoi.graphs.generate import icosahedral -from anemoi.graphs.nodes.builder import HexRefinedIcosahedralNodeBuilder -from anemoi.graphs.nodes.builder import TriRefinedIcosahedralNodeBuilder +from anemoi.graphs.nodes.builder import HexRefinedIcosahedralNodes +from anemoi.graphs.nodes.builder import TriRefinedIcosahedralNodes from anemoi.graphs.utils import get_grid_reference_distance logger = logging.getLogger(__name__) @@ -129,8 +129,8 @@ def get_adj_matrix(self, src_nodes: NodeStorage, dst_nodes: NodeStorage): class TriIcosahedralEdges(BaseEdgeBuilder): """Computes icosahedral edges and adds them to a HeteroData graph.""" - def __init__(self, src_name: str, dst_name: str, xhops: int): - super().__init__(src_name, dst_name) + def __init__(self, src_name: str, xhops: int): + super().__init__(src_name, src_name) assert isinstance(xhops, int), "Number of xhops must be an integer" assert xhops > 0, "Number of xhops must be positive" @@ -140,11 +140,8 @@ def __init__(self, src_name: str, dst_name: str, xhops: int): def transform(self, graph: HeteroData, edge_name: str, attrs_config: Optional[DotDict] = None) -> HeteroData: assert ( - graph[self.src_name].node_type == TriRefinedIcosahedralNodeBuilder.__name__ - ), "IcosahedralConnection requires MultiScaleIcosahedral nodes." - assert ( - graph[self.src_name] == graph[self.dst_name] - ), "InheritConnection requires the same nodes for source and destination." + graph[self.src_name].node_type == TriRefinedIcosahedralNodes.__name__ + ), "IcosahedralConnection requires TriRefinedIcosahedralNodes." # TODO: Next assert doesn't exist anymore since filters were moved, make sure this is checked where appropriate # assert filter_src is None and filter_dst is None, "InheritConnection does not support filtering with attributes." @@ -173,20 +170,15 @@ def get_adj_matrix(self, src_nodes: NodeStorage, dst_nodes: NodeStorage): class HexagonalEdges(BaseEdgeBuilder): """Computes hexagonal edges and adds them to a HeteroData graph.""" - def __init__( - self, src_name: str, dst_name: str, add_neighbouring_children: bool = False, depth_children: Optional[int] = 1 - ): - super().__init__(src_name, dst_name) + def __init__(self, src_name: str, add_neighbouring_children: bool = False, depth_children: Optional[int] = 1): + super().__init__(src_name, src_name) self.add_neighbouring_children = add_neighbouring_children self.depth_children = depth_children def transform(self, graph: HeteroData, edge_name: str, attrs_config: Optional[DotDict] = None) -> HeteroData: assert ( - graph[self.src_name].node_type == HexRefinedIcosahedralNodeBuilder.__name__ - ), "IcosahedralConnection requires MultiScaleIcosahedral nodes." - assert ( - graph[self.src_name] == graph[self.dst_name] - ), "InheritConnection requires the same nodes for source and destination." + graph[self.src_name].node_type == HexRefinedIcosahedralNodes.__name__ + ), "HexagonalEdges requires HexRefinedIcosahedralNodes." # TODO: Next assert doesn't exist anymore since filters were moved, make sure this is checked where appropriate # assert filter_src is None and filter_dst is None, "InheritConnection does not support filtering with attributes." From e410bf58c997686d517698cf00f0424545e33c9e Mon Sep 17 00:00:00 2001 From: theissenhelen Date: Mon, 1 Jul 2024 16:50:01 +0000 Subject: [PATCH 039/156] fix: imports --- src/anemoi/graphs/edges/__init__.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/anemoi/graphs/edges/__init__.py b/src/anemoi/graphs/edges/__init__.py index 53b9c74..e95e3cd 100644 --- a/src/anemoi/graphs/edges/__init__.py +++ b/src/anemoi/graphs/edges/__init__.py @@ -1,4 +1,6 @@ from .builder import CutOffEdges +from .builder import HexagonalEdges from .builder import KNNEdges +from .builder import TriIcosahedralEdges -__all__ = ["KNNEdges", "CutOffEdges"] +__all__ = ["KNNEdges", "CutOffEdges", "TriIcosahedralEdges", "HexagonalEdges"] From ef1c110aa2c1092eb89aa1397e0e77cc699b11e3 Mon Sep 17 00:00:00 2001 From: theissenhelen Date: Mon, 1 Jul 2024 17:22:41 +0000 Subject: [PATCH 040/156] fix: edge_name not supported --- src/anemoi/graphs/edges/builder.py | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/src/anemoi/graphs/edges/builder.py b/src/anemoi/graphs/edges/builder.py index e9be430..3e31de6 100644 --- a/src/anemoi/graphs/edges/builder.py +++ b/src/anemoi/graphs/edges/builder.py @@ -137,7 +137,7 @@ def __init__(self, src_name: str, xhops: int): self.xhops = xhops - def transform(self, graph: HeteroData, edge_name: str, attrs_config: Optional[DotDict] = None) -> HeteroData: + def transform(self, graph: HeteroData, attrs_config: Optional[DotDict] = None) -> HeteroData: assert ( graph[self.src_name].node_type == TriRefinedIcosahedralNodes.__name__ @@ -146,7 +146,7 @@ def transform(self, graph: HeteroData, edge_name: str, attrs_config: Optional[Do # TODO: Next assert doesn't exist anymore since filters were moved, make sure this is checked where appropriate # assert filter_src is None and filter_dst is None, "InheritConnection does not support filtering with attributes." - return super().transform(graph, edge_name, attrs_config) + return super().transform(graph, attrs_config) def get_adj_matrix(self, src_nodes: NodeStorage, dst_nodes: NodeStorage): @@ -154,7 +154,6 @@ def get_adj_matrix(self, src_nodes: NodeStorage, dst_nodes: NodeStorage): src_nodes["nx_graph"], resolutions=src_nodes["resolutions"], xhops=self.xhops, - aoi_nneighb=None if "aoi_nneighb" not in src_nodes else src_nodes["aoi_nneigh"], ) # HeteroData refuses to accept None adjmat = nx.to_scipy_sparse_array(src_nodes["nx_graph"], nodelist=list(src_nodes["nx_graph"]), format="coo") @@ -175,7 +174,7 @@ def __init__(self, src_name: str, add_neighbouring_children: bool = False, depth self.add_neighbouring_children = add_neighbouring_children self.depth_children = depth_children - def transform(self, graph: HeteroData, edge_name: str, attrs_config: Optional[DotDict] = None) -> HeteroData: + def transform(self, graph: HeteroData, attrs_config: Optional[DotDict] = None) -> HeteroData: assert ( graph[self.src_name].node_type == HexRefinedIcosahedralNodes.__name__ ), "HexagonalEdges requires HexRefinedIcosahedralNodes." @@ -183,7 +182,7 @@ def transform(self, graph: HeteroData, edge_name: str, attrs_config: Optional[Do # TODO: Next assert doesn't exist anymore since filters were moved, make sure this is checked where appropriate # assert filter_src is None and filter_dst is None, "InheritConnection does not support filtering with attributes." - return super().transform(graph, edge_name, attrs_config) + return super().transform(graph, attrs_config) def get_adj_matrix(self, src_nodes: NodeStorage, dst_nodes: NodeStorage): From 2e6830fb1289cb60900191e530f0f1a1bcac904f Mon Sep 17 00:00:00 2001 From: theissenhelen Date: Mon, 1 Jul 2024 17:23:25 +0000 Subject: [PATCH 041/156] test: add tests for TriIcosahedralEdges --- tests/edges/test_tri_icosahedral_edges.py | 42 +++++++++++++++++++++++ 1 file changed, 42 insertions(+) create mode 100644 tests/edges/test_tri_icosahedral_edges.py diff --git a/tests/edges/test_tri_icosahedral_edges.py b/tests/edges/test_tri_icosahedral_edges.py new file mode 100644 index 0000000..9663cb0 --- /dev/null +++ b/tests/edges/test_tri_icosahedral_edges.py @@ -0,0 +1,42 @@ +import pytest +from torch_geometric.data import HeteroData + +from anemoi.graphs.edges import TriIcosahedralEdges +from anemoi.graphs.nodes import TriRefinedIcosahedralNodes + + +class TestTriIcosahedralEdgesInit: + def test_init(self): + """Test TriIcosahedralEdges initialization.""" + assert isinstance(TriIcosahedralEdges("test_nodes", 1), TriIcosahedralEdges) + + @pytest.mark.parametrize("xhops", [-0.5, "hello", None, -4]) + def test_fail_init(self, xhops: str): + """Test TriIcosahedralEdges initialization with invalid cutoff.""" + with pytest.raises(AssertionError): + TriIcosahedralEdges("test_nodes", xhops) + + +class TestTriIcosahedralEdgesTransform: + + @pytest.fixture() + def ico_graph(self) -> HeteroData: + """Return a HeteroData object with TriRefinedIcosahedralNodes.""" + graph = HeteroData() + graph = TriRefinedIcosahedralNodes(0).transform(graph, "test_nodes", {}) + graph["fail_nodes"].x = [1, 2, 3] + graph["fail_nodes"].node_type = "FailNodes" + return graph + + def test_transform_same_src_dst_nodes(self, ico_graph: HeteroData): + """Test TriIcosahedralEdges transform method.""" + + tri_icosahedral_edges = TriIcosahedralEdges("test_nodes", 1) + graph = tri_icosahedral_edges.transform(ico_graph) + assert ("test_nodes", "to", "test_nodes") in graph.edge_types + + def test_transform_fail_nodes(self, ico_graph: HeteroData): + """Test TriIcosahedralEdges transform method with wrong node type.""" + tri_icosahedral_edges = TriIcosahedralEdges("fail_nodes", 1) + with pytest.raises(AssertionError): + tri_icosahedral_edges.transform(ico_graph) From a59f5d16e2f5997094745664f38d0b69f349f31b Mon Sep 17 00:00:00 2001 From: theissenhelen Date: Tue, 2 Jul 2024 08:45:13 +0000 Subject: [PATCH 042/156] fix: assert missing for Hexagonal edges --- src/anemoi/graphs/edges/builder.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/src/anemoi/graphs/edges/builder.py b/src/anemoi/graphs/edges/builder.py index 3e31de6..70565ca 100644 --- a/src/anemoi/graphs/edges/builder.py +++ b/src/anemoi/graphs/edges/builder.py @@ -172,6 +172,9 @@ class HexagonalEdges(BaseEdgeBuilder): def __init__(self, src_name: str, add_neighbouring_children: bool = False, depth_children: Optional[int] = 1): super().__init__(src_name, src_name) self.add_neighbouring_children = add_neighbouring_children + + assert isinstance(depth_children, int), "Depth of children must be an integer" + assert depth_children > 0, "Depth of children must be positive" self.depth_children = depth_children def transform(self, graph: HeteroData, attrs_config: Optional[DotDict] = None) -> HeteroData: From 59bac560f5d5dc2f9cb3be0bc22eabca1dcb49c1 Mon Sep 17 00:00:00 2001 From: theissenhelen Date: Tue, 2 Jul 2024 08:46:08 +0000 Subject: [PATCH 043/156] test: hexagonal edges --- tests/edges/test_hex_refined_icosahedral.py | 42 +++++++++++++++++++++ 1 file changed, 42 insertions(+) create mode 100644 tests/edges/test_hex_refined_icosahedral.py diff --git a/tests/edges/test_hex_refined_icosahedral.py b/tests/edges/test_hex_refined_icosahedral.py new file mode 100644 index 0000000..d6bb29c --- /dev/null +++ b/tests/edges/test_hex_refined_icosahedral.py @@ -0,0 +1,42 @@ +import pytest +from torch_geometric.data import HeteroData + +from anemoi.graphs.edges import HexagonalEdges +from anemoi.graphs.nodes import HexRefinedIcosahedralNodes + + +class TestTriIcosahedralEdgesInit: + def test_init(self): + """Test TriIcosahedralEdges initialization.""" + assert isinstance(HexagonalEdges("test_nodes"), HexagonalEdges) + + @pytest.mark.parametrize("depth_children", [-0.5, "hello", None, -4]) + def test_fail_init(self, depth_children: str): + """Test HexagonalEdges initialization with invalid cutoff.""" + with pytest.raises(AssertionError): + HexagonalEdges("test_nodes", True, depth_children) + + +class TestTriIcosahedralEdgesTransform: + + @pytest.fixture() + def ico_graph(self) -> HeteroData: + """Return a HeteroData object with HexRefinedIcosahedralNodes.""" + graph = HeteroData() + graph = HexRefinedIcosahedralNodes(0).transform(graph, "test_nodes", {}) + graph["fail_nodes"].x = [1, 2, 3] + graph["fail_nodes"].node_type = "FailNodes" + return graph + + def test_transform_same_src_dst_nodes(self, ico_graph: HeteroData): + """Test HexagonalEdges transform method.""" + + tri_icosahedral_edges = HexagonalEdges("test_nodes") + graph = tri_icosahedral_edges.transform(ico_graph) + assert ("test_nodes", "to", "test_nodes") in graph.edge_types + + def test_transform_fail_nodes(self, ico_graph: HeteroData): + """Test HexagonalEdges transform method with wrong node type.""" + tri_icosahedral_edges = HexagonalEdges("fail_nodes") + with pytest.raises(AssertionError): + tri_icosahedral_edges.transform(ico_graph) From bd729c9ddf5131ec9de3808b515e5effd92d3cb7 Mon Sep 17 00:00:00 2001 From: theissenhelen Date: Tue, 2 Jul 2024 09:49:12 +0000 Subject: [PATCH 044/156] fix: avoid same name --- .../{test_hex_refined_icosahedral.py => test_hexagonal_edges.py} | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename tests/edges/{test_hex_refined_icosahedral.py => test_hexagonal_edges.py} (100%) diff --git a/tests/edges/test_hex_refined_icosahedral.py b/tests/edges/test_hexagonal_edges.py similarity index 100% rename from tests/edges/test_hex_refined_icosahedral.py rename to tests/edges/test_hexagonal_edges.py From 9cce37adb3920c960306eac2bd2b8a9d947c1220 Mon Sep 17 00:00:00 2001 From: theissenhelen Date: Wed, 3 Jul 2024 13:10:20 +0000 Subject: [PATCH 045/156] feat: LimitedAreaZarrNodes --- src/anemoi/graphs/nodes/__init__.py | 9 ++++++++- src/anemoi/graphs/nodes/builder.py | 23 +++++++++++++++++++++-- 2 files changed, 29 insertions(+), 3 deletions(-) diff --git a/src/anemoi/graphs/nodes/__init__.py b/src/anemoi/graphs/nodes/__init__.py index ef13e41..07e4b0c 100644 --- a/src/anemoi/graphs/nodes/__init__.py +++ b/src/anemoi/graphs/nodes/__init__.py @@ -1,6 +1,13 @@ from .builder import HexRefinedIcosahedralNodes +from .builder import LimitedAreaZarrDatasetNodes from .builder import NPZFileNodes from .builder import TriRefinedIcosahedralNodes from .builder import ZarrDatasetNodes -__all__ = ["ZarrDatasetNodes", "NPZFileNodes", "TriRefinedIcosahedralNodes", "HexRefinedIcosahedralNodes"] +__all__ = [ + "ZarrDatasetNodes", + "NPZFileNodes", + "TriRefinedIcosahedralNodes", + "HexRefinedIcosahedralNodes", + "LimitedAreaZarrDatasetNodes", +] diff --git a/src/anemoi/graphs/nodes/builder.py b/src/anemoi/graphs/nodes/builder.py index a477020..ad8d098 100644 --- a/src/anemoi/graphs/nodes/builder.py +++ b/src/anemoi/graphs/nodes/builder.py @@ -49,10 +49,29 @@ class ZarrDatasetNodes(BaseNodeBuilder): def __init__(self, dataset: DotDict) -> None: logger.info("Reading the dataset from %s.", dataset) - self.ds = open_dataset(dataset) + self.dataset = open_dataset(dataset) def get_coordinates(self) -> torch.Tensor: - return self.reshape_coords(self.ds.latitudes, self.ds.longitudes) + return self.reshape_coords(self.dataset.latitudes, self.dataset.longitudes) + + +class LimitedAreaZarrDatasetNodes(ZarrDatasetNodes): + """Nodes from Zarr dataset.""" + + def __init__(self, lam_dataset: str, forcing_dataset: str, thinning: int = 1, adjust: str = "all") -> None: + dataset_config = { + "cutout": [{"dataset": lam_dataset, "thinning": thinning}, {"dataset": forcing_dataset}], + "adjust": adjust, + } + super().__init__(dataset_config) + self.n_cutout, self.n_other = self.dataset.grids + + def register_attributes(self, graph: HeteroData, name: str, config: DotDict) -> None: + # this is a mask to cutout the LAM area + graph[name]["cutout"] = torch.tensor([True] * self.n_cutout + [False] * self.n_other, dtype=bool).reshape( + (-1, 1) + ) + return super().register_attributes(graph, name, config) class NPZFileNodes(BaseNodeBuilder): From 745709f08dd83be3b8e5eb94fa25e9d179f0a147 Mon Sep 17 00:00:00 2001 From: theissenhelen Date: Wed, 3 Jul 2024 13:24:20 +0000 Subject: [PATCH 046/156] feat: add KNNMaskBuilder for use with LAM --- src/anemoi/graphs/nodes/builder.py | 71 +++++++++++++++++++++++++++++- 1 file changed, 70 insertions(+), 1 deletion(-) diff --git a/src/anemoi/graphs/nodes/builder.py b/src/anemoi/graphs/nodes/builder.py index ad8d098..1f0292f 100644 --- a/src/anemoi/graphs/nodes/builder.py +++ b/src/anemoi/graphs/nodes/builder.py @@ -9,8 +9,10 @@ from anemoi.datasets import open_dataset from anemoi.utils.config import DotDict from hydra.utils import instantiate +from sklearn.neighbors import NearestNeighbors from torch_geometric.data import HeteroData +from anemoi.graphs import EARTH_RADIUS from anemoi.graphs.generate.hexagonal import create_hexagonal_nodes from anemoi.graphs.generate.icosahedral import create_icosahedral_nodes @@ -20,6 +22,9 @@ class BaseNodeBuilder(ABC): """Base class for node builders.""" + def __init__(self) -> None: + self.aoi_mask_builder = None + def register_nodes(self, graph: HeteroData, name: str) -> None: graph[name].x = self.get_coordinates() graph[name].node_type = type(self).__name__ @@ -126,7 +131,7 @@ def register_attributes(self, graph: HeteroData, name: str, config: DotDict) -> graph[name]["resolutions"] = self.resolutions graph[name]["nx_graph"] = self.nx_graph graph[name]["node_ordering"] = self.node_ordering - # TODO: AOI mask builder is not used in the current implementation. + graph[name]["aoi_mask_builder"] = self.aoi_mask_builder return super().register_attributes(graph, name, config) @@ -144,3 +149,67 @@ class HexRefinedIcosahedralNodes(RefinedIcosahedralNodes): def create_nodes(self) -> np.ndarray: # TODO: AOI mask builder is not used in the current implementation. return create_hexagonal_nodes(self.resolutions) + + +class AreaTriRefinedIcosahedralNodeBuilder(TriRefinedIcosahedralNodes): + """Class to build icosahedral nodes with a limited area of interest.""" + + def __init__( + self, + resolution: int | list[int], + reference_node_name: str, + mask_attr_name: str, + margin_radius_km: float = 100.0, + np_dtype: np.dtype = np.float32, + ) -> None: + + super().__init__(resolution, np_dtype) + + self.aoi_mask_builder = KNNAreaMaskBuilder(reference_node_name, margin_radius_km, mask_attr_name) + + def register_nodes(self, graph: HeteroData, name: str) -> None: + self.aoi_mask_builder.fit(graph) + return super().register_nodes(graph, name) + + +class AreaHexRefinedIcosahedralNodeBuilder(HexRefinedIcosahedralNodes): + """Class to build icosahedral nodes with a limited area of interest.""" + + def __init__( + self, + resolution: int | list[int], + reference_node_name: str, + mask_attr_name: str, + margin_radius_km: float = 100.0, + np_dtype: np.dtype = np.float32, + ) -> None: + + super().__init__(resolution, np_dtype) + + self.aoi_mask_builder = KNNAreaMaskBuilder(reference_node_name, margin_radius_km, mask_attr_name) + + def register_nodes(self, graph: HeteroData, name: str) -> None: + self.aoi_mask_builder.fit(graph) + return super().register_nodes(graph, name) + + +class KNNAreaMaskBuilder: + """Class to build a mask based on distance to masked reference nodes using KNN.""" + + def __init__(self, reference_node_name: str, margin_radius_km: float, mask_attr_name: str): + + self.nearest_neighbour = NearestNeighbors(metric="haversine", n_jobs=4) + self.margin_radius_km = margin_radius_km + self.reference_node_name = reference_node_name + self.mask_attr_name = mask_attr_name + + def fit(self, graph: HeteroData): + coords_rad = graph[self.reference_node_name].x.numpy() + mask = graph[self.reference_node_name].mask_attr_name + self.nearest_neighbour.fit(coords_rad[mask]) + + def get_mask(self, coords_rad: np.ndarray): + + neigh_dists, _ = self.nearest_neighbour.kneighbors(coords_rad, n_neighbors=1) + mask = neigh_dists[:, 0] * EARTH_RADIUS <= self.margin_radius_km + return mask From bc735cd7da13706ef7e76ccf07423ce232352b50 Mon Sep 17 00:00:00 2001 From: theissenhelen Date: Wed, 3 Jul 2024 13:38:37 +0000 Subject: [PATCH 047/156] feat: add KNNMaskBuilder to TriIcosahedral --- src/anemoi/graphs/edges/builder.py | 1 + src/anemoi/graphs/generate/icosahedral.py | 18 ++++++++++++------ 2 files changed, 13 insertions(+), 6 deletions(-) diff --git a/src/anemoi/graphs/edges/builder.py b/src/anemoi/graphs/edges/builder.py index 70565ca..a9ee46b 100644 --- a/src/anemoi/graphs/edges/builder.py +++ b/src/anemoi/graphs/edges/builder.py @@ -154,6 +154,7 @@ def get_adj_matrix(self, src_nodes: NodeStorage, dst_nodes: NodeStorage): src_nodes["nx_graph"], resolutions=src_nodes["resolutions"], xhops=self.xhops, + aoi_nneighb=None if "aoi_mask_builder" not in src_nodes else src_nodes["aoi_mask_builder"], ) # HeteroData refuses to accept None adjmat = nx.to_scipy_sparse_array(src_nodes["nx_graph"], nodelist=list(src_nodes["nx_graph"]), format="coo") diff --git a/src/anemoi/graphs/generate/icosahedral.py b/src/anemoi/graphs/generate/icosahedral.py index 7f124cc..e96595e 100644 --- a/src/anemoi/graphs/generate/icosahedral.py +++ b/src/anemoi/graphs/generate/icosahedral.py @@ -9,12 +9,13 @@ from sklearn.neighbors import BallTree from anemoi.graphs.generate.transforms import cartesian_to_latlon_rad +from anemoi.graphs.nodes import KNNAreaMaskBuilder logger = logging.getLogger(__name__) def create_icosahedral_nodes( - resolutions: list[int], + resolutions: list[int], aoi_mask_builder: Optional[KNNAreaMaskBuilder] = None ) -> tuple[nx.DiGraph, np.ndarray, list[int]]: """Creates a global mesh following AIFS strategy. @@ -42,7 +43,9 @@ def create_icosahedral_nodes( node_ordering = get_node_ordering(coords_rad) - # TODO: AOI mask builder is not used in the current implementation. + if aoi_mask_builder is not None: + aoi_mask = aoi_mask_builder.get_mask(coords_rad) + node_ordering = node_ordering[aoi_mask] nx_graph = create_icosahedral_nx_graph_from_coords(coords_rad, node_ordering) @@ -73,6 +76,7 @@ def add_edges_to_nx_graph( graph: nx.DiGraph, resolutions: list[int], xhops: int = 1, + aoi_mask_builder: Optional[KNNAreaMaskBuilder] = None, ) -> None: """Adds the edges to the graph. @@ -86,8 +90,6 @@ def add_edges_to_nx_graph( Number of hops between 2 nodes to consider them neighbours, by default 1. aoi_mask_builder : KNNAreaMaskBuilder NearestNeighbors with the cloud of points to limit the mesh area, by default None. - margin_radius_km : float, optional - Margin radius in km to consider when creating the processor mesh, by default 0.0. """ assert xhops > 0, "xhops == 0, graph would have no edges ..." @@ -105,8 +107,12 @@ def add_edges_to_nx_graph( r_sphere = trimesh.creation.icosphere(subdivisions=resolution, radius=1.0) r_vertices_rad = cartesian_to_latlon_rad(r_sphere.vertices) - # TODO AOI mask builder is not used in the current implementation. - valid_nodes = None + # Limit area of mesh points. + if aoi_mask_builder is not None: + aoi_mask = aoi_mask_builder.get_mask(vertices_rad) + valid_nodes = np.where(aoi_mask)[0] + else: + valid_nodes = None x_hops = get_x_hops(r_sphere, xhops, valid_nodes=valid_nodes) From 03fbf9f5ca5c618a492c2340749ec56000813f74 Mon Sep 17 00:00:00 2001 From: theissenhelen Date: Wed, 3 Jul 2024 13:50:19 +0000 Subject: [PATCH 048/156] feat: AreaNPZFileNodes --- src/anemoi/graphs/nodes/builder.py | 36 ++++++++++++++++++++++++++++++ 1 file changed, 36 insertions(+) diff --git a/src/anemoi/graphs/nodes/builder.py b/src/anemoi/graphs/nodes/builder.py index 1f0292f..c0a3cfe 100644 --- a/src/anemoi/graphs/nodes/builder.py +++ b/src/anemoi/graphs/nodes/builder.py @@ -92,6 +92,42 @@ def get_coordinates(self) -> np.ndarray: return coords +class AreaNPZFileNodes(NPZFileNodes): + """Processor mesh based on an NPZ defined grids using an area of interest.""" + + def __init__( + self, + resolution: str, + grid_definition_path: str, + reference_node_name: str, + mask_attr_name: str, + margin_radius_km: float = 100.0, + np_dtype: np.dtype = np.float32, + ) -> None: + + self.aoi_mask_builder = KNNAreaMaskBuilder(reference_node_name, margin_radius_km, mask_attr_name) + + super().__init__(resolution, grid_definition_path, np_dtype) + + def register_nodes(self, graph: HeteroData, name: str) -> None: + self.aoi_mask_builder.fit(graph) + return super().register_nodes(graph, name) + + def get_coordinates(self) -> np.ndarray: + coords = super().get_coordinates() + + logger.info( + "Limiting the processor mesh to a radius of %.2f km from the output mesh.", + self.aoi_mask_builder.margin_radius_km, + ) + aoi_mask = self.aoi_mask_builder.get_mask(np.deg2rad(coords)) + + logger.info("Dropping %d nodes from the processor mesh.", len(aoi_mask) - aoi_mask.sum()) + coords = coords[aoi_mask] + + return coords + + class RefinedIcosahedralNodes(BaseNodeBuilder, ABC): """Processor mesh based on a triangular mesh. From 3731e826c4295748d3a58fbd64ad441eca18dd81 Mon Sep 17 00:00:00 2001 From: theissenhelen Date: Wed, 3 Jul 2024 16:29:57 +0000 Subject: [PATCH 049/156] fix: KNNAreaMaskBuilder working with NPZ --- src/anemoi/graphs/generate/icosahedral.py | 2 +- src/anemoi/graphs/nodes/__init__.py | 2 ++ src/anemoi/graphs/nodes/builder.py | 28 ++--------------------- src/anemoi/graphs/nodes/masks.py | 27 ++++++++++++++++++++++ 4 files changed, 32 insertions(+), 27 deletions(-) create mode 100644 src/anemoi/graphs/nodes/masks.py diff --git a/src/anemoi/graphs/generate/icosahedral.py b/src/anemoi/graphs/generate/icosahedral.py index e96595e..828877a 100644 --- a/src/anemoi/graphs/generate/icosahedral.py +++ b/src/anemoi/graphs/generate/icosahedral.py @@ -9,7 +9,7 @@ from sklearn.neighbors import BallTree from anemoi.graphs.generate.transforms import cartesian_to_latlon_rad -from anemoi.graphs.nodes import KNNAreaMaskBuilder +from anemoi.graphs.nodes.masks import KNNAreaMaskBuilder logger = logging.getLogger(__name__) diff --git a/src/anemoi/graphs/nodes/__init__.py b/src/anemoi/graphs/nodes/__init__.py index 07e4b0c..21f92a1 100644 --- a/src/anemoi/graphs/nodes/__init__.py +++ b/src/anemoi/graphs/nodes/__init__.py @@ -1,3 +1,4 @@ +from .builder import AreaNPZFileNodes from .builder import HexRefinedIcosahedralNodes from .builder import LimitedAreaZarrDatasetNodes from .builder import NPZFileNodes @@ -10,4 +11,5 @@ "TriRefinedIcosahedralNodes", "HexRefinedIcosahedralNodes", "LimitedAreaZarrDatasetNodes", + "AreaNPZFileNodes", ] diff --git a/src/anemoi/graphs/nodes/builder.py b/src/anemoi/graphs/nodes/builder.py index c0a3cfe..d4a0324 100644 --- a/src/anemoi/graphs/nodes/builder.py +++ b/src/anemoi/graphs/nodes/builder.py @@ -9,12 +9,11 @@ from anemoi.datasets import open_dataset from anemoi.utils.config import DotDict from hydra.utils import instantiate -from sklearn.neighbors import NearestNeighbors from torch_geometric.data import HeteroData -from anemoi.graphs import EARTH_RADIUS from anemoi.graphs.generate.hexagonal import create_hexagonal_nodes from anemoi.graphs.generate.icosahedral import create_icosahedral_nodes +from anemoi.graphs.nodes.masks import KNNAreaMaskBuilder logger = logging.getLogger(__name__) @@ -102,12 +101,11 @@ def __init__( reference_node_name: str, mask_attr_name: str, margin_radius_km: float = 100.0, - np_dtype: np.dtype = np.float32, ) -> None: self.aoi_mask_builder = KNNAreaMaskBuilder(reference_node_name, margin_radius_km, mask_attr_name) - super().__init__(resolution, grid_definition_path, np_dtype) + super().__init__(resolution, grid_definition_path) def register_nodes(self, graph: HeteroData, name: str) -> None: self.aoi_mask_builder.fit(graph) @@ -227,25 +225,3 @@ def __init__( def register_nodes(self, graph: HeteroData, name: str) -> None: self.aoi_mask_builder.fit(graph) return super().register_nodes(graph, name) - - -class KNNAreaMaskBuilder: - """Class to build a mask based on distance to masked reference nodes using KNN.""" - - def __init__(self, reference_node_name: str, margin_radius_km: float, mask_attr_name: str): - - self.nearest_neighbour = NearestNeighbors(metric="haversine", n_jobs=4) - self.margin_radius_km = margin_radius_km - self.reference_node_name = reference_node_name - self.mask_attr_name = mask_attr_name - - def fit(self, graph: HeteroData): - coords_rad = graph[self.reference_node_name].x.numpy() - mask = graph[self.reference_node_name].mask_attr_name - self.nearest_neighbour.fit(coords_rad[mask]) - - def get_mask(self, coords_rad: np.ndarray): - - neigh_dists, _ = self.nearest_neighbour.kneighbors(coords_rad, n_neighbors=1) - mask = neigh_dists[:, 0] * EARTH_RADIUS <= self.margin_radius_km - return mask diff --git a/src/anemoi/graphs/nodes/masks.py b/src/anemoi/graphs/nodes/masks.py new file mode 100644 index 0000000..786796f --- /dev/null +++ b/src/anemoi/graphs/nodes/masks.py @@ -0,0 +1,27 @@ +import numpy as np +from sklearn.neighbors import NearestNeighbors +from torch_geometric.data import HeteroData + +from anemoi.graphs import EARTH_RADIUS + + +class KNNAreaMaskBuilder: + """Class to build a mask based on distance to masked reference nodes using KNN.""" + + def __init__(self, reference_node_name: str, margin_radius_km: float, mask_attr_name: str): + + self.nearest_neighbour = NearestNeighbors(metric="haversine", n_jobs=4) + self.margin_radius_km = margin_radius_km + self.reference_node_name = reference_node_name + self.mask_attr_name = mask_attr_name + + def fit(self, graph: HeteroData): + coords_rad = graph[self.reference_node_name].x.numpy() + mask = graph[self.reference_node_name][self.mask_attr_name].squeeze() + self.nearest_neighbour.fit(coords_rad[mask, :]) + + def get_mask(self, coords_rad: np.ndarray): + + neigh_dists, _ = self.nearest_neighbour.kneighbors(coords_rad, n_neighbors=1) + mask = neigh_dists[:, 0] * EARTH_RADIUS <= self.margin_radius_km + return mask From ed64a7ebb0f9eae393471d48fb7ee736f8dd7fe5 Mon Sep 17 00:00:00 2001 From: theissenhelen Date: Wed, 3 Jul 2024 16:41:08 +0000 Subject: [PATCH 050/156] fix: imports and naming --- src/anemoi/graphs/nodes/__init__.py | 4 ++++ src/anemoi/graphs/nodes/builder.py | 4 ++-- 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/src/anemoi/graphs/nodes/__init__.py b/src/anemoi/graphs/nodes/__init__.py index 21f92a1..e515410 100644 --- a/src/anemoi/graphs/nodes/__init__.py +++ b/src/anemoi/graphs/nodes/__init__.py @@ -1,4 +1,6 @@ +from .builder import AreaHexRefinedIcosahedralNodes from .builder import AreaNPZFileNodes +from .builder import AreaTriRefinedIcosahedralNodes from .builder import HexRefinedIcosahedralNodes from .builder import LimitedAreaZarrDatasetNodes from .builder import NPZFileNodes @@ -12,4 +14,6 @@ "HexRefinedIcosahedralNodes", "LimitedAreaZarrDatasetNodes", "AreaNPZFileNodes", + "AreaTriRefinedIcosahedralNodes", + "AreaTriRefinedIcosahedralNodes", ] diff --git a/src/anemoi/graphs/nodes/builder.py b/src/anemoi/graphs/nodes/builder.py index d4a0324..26ee2b7 100644 --- a/src/anemoi/graphs/nodes/builder.py +++ b/src/anemoi/graphs/nodes/builder.py @@ -185,7 +185,7 @@ def create_nodes(self) -> np.ndarray: return create_hexagonal_nodes(self.resolutions) -class AreaTriRefinedIcosahedralNodeBuilder(TriRefinedIcosahedralNodes): +class AreaTriRefinedIcosahedralNodes(TriRefinedIcosahedralNodes): """Class to build icosahedral nodes with a limited area of interest.""" def __init__( @@ -206,7 +206,7 @@ def register_nodes(self, graph: HeteroData, name: str) -> None: return super().register_nodes(graph, name) -class AreaHexRefinedIcosahedralNodeBuilder(HexRefinedIcosahedralNodes): +class AreaHexRefinedIcosahedralNodes(HexRefinedIcosahedralNodes): """Class to build icosahedral nodes with a limited area of interest.""" def __init__( From 1e2c37a9b3d1a5ea865bac141cdbd087af3c1ea2 Mon Sep 17 00:00:00 2001 From: theissenhelen Date: Fri, 5 Jul 2024 09:48:53 +0000 Subject: [PATCH 051/156] fix: TriIocsahedral working for area masks --- src/anemoi/graphs/edges/builder.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/anemoi/graphs/edges/builder.py b/src/anemoi/graphs/edges/builder.py index a9ee46b..61ebbcd 100644 --- a/src/anemoi/graphs/edges/builder.py +++ b/src/anemoi/graphs/edges/builder.py @@ -14,6 +14,7 @@ from anemoi.graphs import EARTH_RADIUS from anemoi.graphs.generate import hexagonal from anemoi.graphs.generate import icosahedral +from anemoi.graphs.nodes.builder import AreaTriRefinedIcosahedralNodes from anemoi.graphs.nodes.builder import HexRefinedIcosahedralNodes from anemoi.graphs.nodes.builder import TriRefinedIcosahedralNodes from anemoi.graphs.utils import get_grid_reference_distance @@ -141,6 +142,7 @@ def transform(self, graph: HeteroData, attrs_config: Optional[DotDict] = None) - assert ( graph[self.src_name].node_type == TriRefinedIcosahedralNodes.__name__ + or graph[self.src_name].node_type == AreaTriRefinedIcosahedralNodes.__name__ ), "IcosahedralConnection requires TriRefinedIcosahedralNodes." # TODO: Next assert doesn't exist anymore since filters were moved, make sure this is checked where appropriate @@ -154,7 +156,7 @@ def get_adj_matrix(self, src_nodes: NodeStorage, dst_nodes: NodeStorage): src_nodes["nx_graph"], resolutions=src_nodes["resolutions"], xhops=self.xhops, - aoi_nneighb=None if "aoi_mask_builder" not in src_nodes else src_nodes["aoi_mask_builder"], + aoi_mask_builder=None if "aoi_mask_builder" not in src_nodes else src_nodes["aoi_mask_builder"], ) # HeteroData refuses to accept None adjmat = nx.to_scipy_sparse_array(src_nodes["nx_graph"], nodelist=list(src_nodes["nx_graph"]), format="coo") From c07c583529311cfeafc0ba059164335c50a057e7 Mon Sep 17 00:00:00 2001 From: theissenhelen Date: Fri, 5 Jul 2024 09:49:20 +0000 Subject: [PATCH 052/156] feat: debugging purposes --- src/anemoi/graphs/create.py | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/src/anemoi/graphs/create.py b/src/anemoi/graphs/create.py index 57b60bc..586761e 100644 --- a/src/anemoi/graphs/create.py +++ b/src/anemoi/graphs/create.py @@ -1,9 +1,12 @@ import logging import os +import hydra import torch from anemoi.utils.config import DotDict from hydra.utils import instantiate +from omegaconf import DictConfig +from omegaconf import OmegaConf from torch_geometric.data import HeteroData logger = logging.getLogger(__name__) @@ -66,3 +69,13 @@ def _path_readable(self) -> bool: return True except FileNotFoundError: return False + + +@hydra.main(version_base=None, config_path="../../../config", config_name="graph_recipe_lam.yaml") +def main(config: DictConfig) -> None: + OmegaConf.resolve(config) + GraphCreator("graph.pt", config, overwrite=True).create() + + +if __name__ == "__main__": + main() From 980ed8d7b04739eef99f07f585b29db799f974ac Mon Sep 17 00:00:00 2001 From: theissenhelen Date: Fri, 5 Jul 2024 10:07:30 +0000 Subject: [PATCH 053/156] refactor: rename tests --- tests/edges/{test_cutoff.py => test_cutoff_edges.py} | 0 tests/edges/{test_attributes.py => test_edge_attributes.py} | 0 tests/edges/{test_knn.py => test_knn_edges.py} | 0 ...t_hex_refined_icosahedral.py => test_hex_icosahedral_nodes.py} | 0 tests/nodes/{test_npz.py => test_npz_nodes.py} | 0 ...t_tri_refined_icosahedral.py => test_tri_icosahedral_nodes.py} | 0 tests/nodes/{test_zarr.py => test_zarr_nodes.py} | 0 7 files changed, 0 insertions(+), 0 deletions(-) rename tests/edges/{test_cutoff.py => test_cutoff_edges.py} (100%) rename tests/edges/{test_attributes.py => test_edge_attributes.py} (100%) rename tests/edges/{test_knn.py => test_knn_edges.py} (100%) rename tests/nodes/{test_hex_refined_icosahedral.py => test_hex_icosahedral_nodes.py} (100%) rename tests/nodes/{test_npz.py => test_npz_nodes.py} (100%) rename tests/nodes/{test_tri_refined_icosahedral.py => test_tri_icosahedral_nodes.py} (100%) rename tests/nodes/{test_zarr.py => test_zarr_nodes.py} (100%) diff --git a/tests/edges/test_cutoff.py b/tests/edges/test_cutoff_edges.py similarity index 100% rename from tests/edges/test_cutoff.py rename to tests/edges/test_cutoff_edges.py diff --git a/tests/edges/test_attributes.py b/tests/edges/test_edge_attributes.py similarity index 100% rename from tests/edges/test_attributes.py rename to tests/edges/test_edge_attributes.py diff --git a/tests/edges/test_knn.py b/tests/edges/test_knn_edges.py similarity index 100% rename from tests/edges/test_knn.py rename to tests/edges/test_knn_edges.py diff --git a/tests/nodes/test_hex_refined_icosahedral.py b/tests/nodes/test_hex_icosahedral_nodes.py similarity index 100% rename from tests/nodes/test_hex_refined_icosahedral.py rename to tests/nodes/test_hex_icosahedral_nodes.py diff --git a/tests/nodes/test_npz.py b/tests/nodes/test_npz_nodes.py similarity index 100% rename from tests/nodes/test_npz.py rename to tests/nodes/test_npz_nodes.py diff --git a/tests/nodes/test_tri_refined_icosahedral.py b/tests/nodes/test_tri_icosahedral_nodes.py similarity index 100% rename from tests/nodes/test_tri_refined_icosahedral.py rename to tests/nodes/test_tri_icosahedral_nodes.py diff --git a/tests/nodes/test_zarr.py b/tests/nodes/test_zarr_nodes.py similarity index 100% rename from tests/nodes/test_zarr.py rename to tests/nodes/test_zarr_nodes.py From 3609681271b722f5d3ec8a151fb28d78d04c6772 Mon Sep 17 00:00:00 2001 From: Mario Santa Cruz <48736305+JPXKQX@users.noreply.github.com> Date: Fri, 5 Jul 2024 17:36:50 +0200 Subject: [PATCH 054/156] Global Encoder-Processor-Decoder graph (#9) * feat: Initial implementation of global graphs Co-authored by: Mario Santa Cruz Co-authored-by: Helen Theissen Co-authored-by: Sara Hahner Co-authored-by: Jesper Dramsch --- .gitignore | 3 + README.md | 8 + pyproject.toml | 8 +- src/anemoi/graphs/__init__.py | 5 +- src/anemoi/graphs/commands/create.py | 28 +++ src/anemoi/graphs/commands/hello.py | 32 ---- src/anemoi/graphs/create.py | 80 ++++++++ src/anemoi/graphs/edges/__init__.py | 4 + src/anemoi/graphs/edges/attributes.py | 148 +++++++++++++++ src/anemoi/graphs/edges/builder.py | 221 +++++++++++++++++++++++ src/anemoi/graphs/edges/directional.py | 85 +++++++++ src/anemoi/graphs/generate/__init__.py | 0 src/anemoi/graphs/generate/transforms.py | 95 ++++++++++ src/anemoi/graphs/nodes/__init__.py | 4 + src/anemoi/graphs/nodes/attributes.py | 115 ++++++++++++ src/anemoi/graphs/nodes/builder.py | 168 +++++++++++++++++ src/anemoi/graphs/normalizer.py | 44 +++++ src/anemoi/graphs/utils.py | 133 ++++++++++++++ tests/conftest.py | 92 ++++++++++ tests/edges/test_cutoff.py | 22 +++ tests/edges/test_edge_attributes.py | 29 +++ tests/edges/test_knn.py | 22 +++ tests/nodes/test_node_attributes.py | 42 +++++ tests/nodes/test_npz.py | 58 ++++++ tests/nodes/test_zarr.py | 50 +++++ tests/test_graphs.py | 30 ++- tests/test_normalizer.py | 55 ++++++ 27 files changed, 1542 insertions(+), 39 deletions(-) create mode 100644 src/anemoi/graphs/commands/create.py delete mode 100644 src/anemoi/graphs/commands/hello.py create mode 100644 src/anemoi/graphs/create.py create mode 100644 src/anemoi/graphs/edges/__init__.py create mode 100644 src/anemoi/graphs/edges/attributes.py create mode 100644 src/anemoi/graphs/edges/builder.py create mode 100644 src/anemoi/graphs/edges/directional.py create mode 100644 src/anemoi/graphs/generate/__init__.py create mode 100644 src/anemoi/graphs/generate/transforms.py create mode 100644 src/anemoi/graphs/nodes/__init__.py create mode 100644 src/anemoi/graphs/nodes/attributes.py create mode 100644 src/anemoi/graphs/nodes/builder.py create mode 100644 src/anemoi/graphs/normalizer.py create mode 100644 src/anemoi/graphs/utils.py create mode 100644 tests/conftest.py create mode 100644 tests/edges/test_cutoff.py create mode 100644 tests/edges/test_edge_attributes.py create mode 100644 tests/edges/test_knn.py create mode 100644 tests/nodes/test_node_attributes.py create mode 100644 tests/nodes/test_npz.py create mode 100644 tests/nodes/test_zarr.py create mode 100644 tests/test_normalizer.py diff --git a/.gitignore b/.gitignore index 2137d4c..1b49006 100644 --- a/.gitignore +++ b/.gitignore @@ -186,3 +186,6 @@ _build/ *.sync _version.py *.code-workspace + +/config* +*.pt diff --git a/README.md b/README.md index f55e7da..607c2d3 100644 --- a/README.md +++ b/README.md @@ -19,6 +19,14 @@ Install via `pip` with: $ pip install anemoi-graphs ``` +## Usage + +Create your graph using the configuration given in the config file. The resulting graph will be saved in the given path. + +``` +$ anemoi-graphs create recipe.yaml my_graph.pt +``` + ## License ``` diff --git a/pyproject.toml b/pyproject.toml index 57fc0ff..cb5bb7f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -50,7 +50,11 @@ dynamic = [ "version", ] dependencies = [ - "anemoi-datasets", + "anemoi-datasets[data]>=0.3.3", + "anemoi-utils>=0.3.6", + "hydra-core>=1.3", + "torch>=2.2", + "torch-geometric>=2.3.1,<2.5", ] optional-dependencies.all = [ @@ -59,6 +63,7 @@ optional-dependencies.dev = [ "nbsphinx", "pandoc", "pytest", + "pytest-mock", "requests", "sphinx", "sphinx-argparse", @@ -80,6 +85,7 @@ optional-dependencies.docs = [ optional-dependencies.tests = [ "pytest", + "pytest-mock", ] urls.Documentation = "https://anemoi-graphs.readthedocs.io/" diff --git a/src/anemoi/graphs/__init__.py b/src/anemoi/graphs/__init__.py index eef2c1d..715b8a4 100644 --- a/src/anemoi/graphs/__init__.py +++ b/src/anemoi/graphs/__init__.py @@ -1,9 +1,10 @@ -# (C) Copyright 2024 European Centre for Medium-Range Weather Forecasts. +# (C) Copyright 2023 European Centre for Medium-Range Weather Forecasts. # This software is licensed under the terms of the Apache Licence Version 2.0 # which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. # In applying this licence, ECMWF does not waive the privileges and immunities # granted to it by virtue of its status as an intergovernmental organisation # nor does it submit to any jurisdiction. +from ._version import __version__ -from ._version import __version__ as __version__ +EARTH_RADIUS = 6371.0 # km diff --git a/src/anemoi/graphs/commands/create.py b/src/anemoi/graphs/commands/create.py new file mode 100644 index 0000000..18b3127 --- /dev/null +++ b/src/anemoi/graphs/commands/create.py @@ -0,0 +1,28 @@ +from anemoi.graphs.create import GraphCreator + +from . import Command + + +class Create(Command): + """Create a graph.""" + + internal = True + timestamp = True + + def add_arguments(self, command_parser): + command_parser.add_argument( + "--overwrite", + action="store_true", + help="Overwrite existing files. This will delete the target graph if it already exists.", + ) + command_parser.add_argument("config", help="Configuration yaml file defining the recipe to create the graph.") + command_parser.add_argument("path", help="Path to store the created graph.") + + def run(self, args): + kwargs = vars(args) + + c = GraphCreator(**kwargs) + c.create() + + +command = Create diff --git a/src/anemoi/graphs/commands/hello.py b/src/anemoi/graphs/commands/hello.py deleted file mode 100644 index 12a0495..0000000 --- a/src/anemoi/graphs/commands/hello.py +++ /dev/null @@ -1,32 +0,0 @@ -#!/usr/bin/env python -# (C) Copyright 2024 ECMWF. -# -# This software is licensed under the terms of the Apache Licence Version 2.0 -# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. -# In applying this licence, ECMWF does not waive the privileges and immunities -# granted to it by virtue of its status as an intergovernmental organisation -# nor does it submit to any jurisdiction. -# - -"""Command place holder. Delete when we have real commands. - -""" - -from . import Command - - -def say_hello(greetings, who): - print(greetings, who) - - -class Hello(Command): - - def add_arguments(self, command_parser): - command_parser.add_argument("--greetings", default="hello") - command_parser.add_argument("--who", default="world") - - def run(self, args): - say_hello(args.greetings, args.who) - - -command = Hello diff --git a/src/anemoi/graphs/create.py b/src/anemoi/graphs/create.py new file mode 100644 index 0000000..0b09649 --- /dev/null +++ b/src/anemoi/graphs/create.py @@ -0,0 +1,80 @@ +import logging +import os + +import torch +from anemoi.utils.config import DotDict +from hydra.utils import instantiate +from torch_geometric.data import HeteroData + +LOGGER = logging.getLogger(__name__) + + +class GraphCreator: + """Graph creator.""" + + def __init__( + self, + path, + config=None, + cache=None, + print=print, + overwrite=False, + **kwargs, + ): + if isinstance(config, str) or isinstance(config, os.PathLike): + self.config = DotDict.from_file(config) + else: + self.config = config + + self.path = path # Output path + self.cache = cache + self.print = print + self.overwrite = overwrite + + def init(self): + if self._path_readable() and not self.overwrite: + raise Exception(f"{self.path} already exists. Use overwrite=True to overwrite.") + + def generate_graph(self) -> HeteroData: + """Generate the graph. + + It instantiates the node builders and edge builders defined in the configuration + file and applies them to the graph. + + Returns + ------- + HeteroData: The generated graph. + """ + graph = HeteroData() + for name, nodes_cfg in self.config.nodes.items(): + graph = instantiate(nodes_cfg.node_builder).update_graph(graph, name, nodes_cfg.get("attributes", {})) + + for edges_cfg in self.config.edges: + graph = instantiate(edges_cfg.edge_builder, **edges_cfg.nodes).update_graph( + graph, edges_cfg.get("attributes", {}) + ) + + return graph + + def save(self, graph: HeteroData) -> None: + """Save the graph to the output path.""" + if not os.path.exists(self.path) or self.overwrite: + torch.save(graph, self.path) + self.print(f"Graph saved at {self.path}.") + + def create(self) -> HeteroData: + """Create the graph and save it to the output path.""" + self.init() + graph = self.generate_graph() + self.save(graph) + return graph + + def _path_readable(self) -> bool: + """Check if the output path is readable.""" + import torch + + try: + torch.load(self.path) + return True + except FileNotFoundError: + return False diff --git a/src/anemoi/graphs/edges/__init__.py b/src/anemoi/graphs/edges/__init__.py new file mode 100644 index 0000000..53b9c74 --- /dev/null +++ b/src/anemoi/graphs/edges/__init__.py @@ -0,0 +1,4 @@ +from .builder import CutOffEdges +from .builder import KNNEdges + +__all__ = ["KNNEdges", "CutOffEdges"] diff --git a/src/anemoi/graphs/edges/attributes.py b/src/anemoi/graphs/edges/attributes.py new file mode 100644 index 0000000..6945867 --- /dev/null +++ b/src/anemoi/graphs/edges/attributes.py @@ -0,0 +1,148 @@ +import logging +from abc import ABC +from abc import abstractmethod +from dataclasses import dataclass +from typing import Optional + +import numpy as np +import torch +from torch_geometric.data import HeteroData + +from anemoi.graphs.edges.directional import directional_edge_features +from anemoi.graphs.normalizer import NormalizerMixin +from anemoi.graphs.utils import haversine_distance + +LOGGER = logging.getLogger(__name__) + + +@dataclass +class BaseEdgeAttribute(ABC, NormalizerMixin): + """Base class for edge attributes.""" + + norm: Optional[str] = None + + @abstractmethod + def get_raw_values(self, graph: HeteroData, source_name: str, target_name: str, *args, **kwargs) -> np.ndarray: ... + + def post_process(self, values: np.ndarray) -> torch.Tensor: + """Post-process the values.""" + if values.ndim == 1: + values = values[:, np.newaxis] + + return torch.tensor(values) + + def compute(self, graph: HeteroData, source_name: str, target_name: str, *args, **kwargs) -> torch.Tensor: + """Compute the edge attributes.""" + assert ( + source_name in graph.node_types + ), f"Node \"{source_name}\" not found in graph. Optional nodes are {', '.join(graph.node_types)}." + assert ( + target_name in graph.node_types + ), f"Node \"{target_name}\" not found in graph. Optional nodes are {', '.join(graph.node_types)}." + + values = self.get_raw_values(graph, source_name, target_name, *args, **kwargs) + normed_values = self.normalize(values) + return self.post_process(normed_values) + + +@dataclass +class EdgeDirection(BaseEdgeAttribute): + """Compute directional features for edges. + + If using the rotated features, the direction of the edge is computed + rotating the target nodes to the north pole. If not, it is computed + as the diference in latitude and longitude between the source and + target nodes. + + Attributes + ---------- + norm : Optional[str] + Normalization method. + luse_rotated_features : bool + Whether to use rotated features. + + Methods + ------- + get_raw_values(graph, source_name, target_name) + Compute directions between nodes connected by edges. + compute(graph, source_name, target_name) + Compute directional attributes. + """ + + norm: Optional[str] = None + luse_rotated_features: bool = True + + def get_raw_values(self, graph: HeteroData, source_name: str, target_name: str) -> np.ndarray: + """Compute directional features for edges. + + Parameters + ---------- + graph : HeteroData + The graph. + source_name : str + The name of the source nodes. + target_name : str + The name of the target nodes. + + Returns + ------- + np.ndarray + The directional features. + """ + edge_index = graph[(source_name, "to", target_name)].edge_index + source_coords = graph[source_name].x.numpy()[edge_index[0]].T + target_coords = graph[target_name].x.numpy()[edge_index[1]].T + edge_dirs = directional_edge_features(source_coords, target_coords, self.luse_rotated_features).T + return edge_dirs + + +@dataclass +class EdgeLength(BaseEdgeAttribute): + """Edge length feature. + + Attributes + ---------- + norm : str + Normalization method. + invert : bool + Whether to invert the edge lengths, i.e. 1 - edge_length. + + Methods + ------- + get_raw_values(graph, source_name, target_name) + Compute haversine distance between nodes connected by edges. + compute(graph, source_name, target_name) + Compute edge lengths attributes. + """ + + norm: str = "l1" + invert: bool = True + + def get_raw_values(self, graph: HeteroData, source_name: str, target_name: str) -> np.ndarray: + """Compute haversine distance (in kilometers) between nodes connected by edges. + + Parameters + ---------- + graph : HeteroData + The graph. + source_name : str + The name of the source nodes. + target_name : str + The name of the target nodes. + + Returns + ------- + np.ndarray + The edge lengths. + """ + edge_index = graph[(source_name, "to", target_name)].edge_index + source_coords = graph[source_name].x.numpy()[edge_index[0]] + target_coords = graph[target_name].x.numpy()[edge_index[1]] + edge_lengths = haversine_distance(source_coords, target_coords) + return edge_lengths + + def post_process(self, values: np.ndarray) -> torch.Tensor: + """Post-process edge lengths.""" + if self.invert: + values = 1 - values + return super().post_process(values) diff --git a/src/anemoi/graphs/edges/builder.py b/src/anemoi/graphs/edges/builder.py new file mode 100644 index 0000000..3c926d6 --- /dev/null +++ b/src/anemoi/graphs/edges/builder.py @@ -0,0 +1,221 @@ +import logging +from abc import ABC +from abc import abstractmethod +from typing import Optional + +import numpy as np +import torch +from anemoi.utils.config import DotDict +from hydra.utils import instantiate +from sklearn.neighbors import NearestNeighbors +from torch_geometric.data import HeteroData +from torch_geometric.data.storage import NodeStorage + +from anemoi.graphs import EARTH_RADIUS +from anemoi.graphs.utils import get_grid_reference_distance + +LOGGER = logging.getLogger(__name__) + + +class BaseEdgeBuilder(ABC): + + def __init__(self, source_name: str, target_name: str): + super().__init__() + self.source_name = source_name + self.target_name = target_name + + @abstractmethod + def get_adjacency_matrix(self, source_nodes: NodeStorage, target_nodes: NodeStorage): ... + + def register_edges(self, graph: HeteroData, source_indices: np.ndarray, target_indices: np.ndarray) -> HeteroData: + """Register edges in the graph. + + Parameters + ---------- + graph : HeteroData + The graph to register the edges. + source_indices : np.ndarray of shape (N, ) + The indices of the source nodes. + target_indices : np.ndarray of shape (N, ) + The indices of the target nodes. + + Returns + ------- + HeteroData + The graph with the registered edges. + """ + edge_index = np.stack([source_indices, target_indices], axis=0).astype(np.int32) + graph[(self.source_name, "to", self.target_name)].edge_index = torch.from_numpy(edge_index) + return graph + + def register_attributes(self, graph: HeteroData, config: DotDict) -> HeteroData: + """Register attributes in the edges of the graph specified. + + Parameters + ---------- + graph : HeteroData + The graph to register the attributes. + config : DotDict + The configuration of the attributes. + + Returns + ------- + HeteroData + The graph with the registered attributes. + """ + for attr_name, attr_config in config.items(): + graph[self.source_name, "to", self.target_name][attr_name] = instantiate(attr_config).compute( + graph, self.source_name, self.target_name + ) + return graph + + def prepare_node_data(self, graph: HeteroData) -> tuple[NodeStorage, NodeStorage]: + """Prepare nodes information.""" + return graph[self.source_name], graph[self.target_name] + + def update_graph(self, graph: HeteroData, attrs_config: Optional[DotDict] = None) -> HeteroData: + """Update the graph with the edges. + + Parameters + ---------- + graph : HeteroData + The graph. + attrs_config : DotDict + The configuration of the edge attributes. + + Returns + ------- + HeteroData + The graph with the edges. + """ + source_nodes, target_nodes = self.prepare_node_data(graph) + + adjmat = self.get_adjacency_matrix(source_nodes, target_nodes) + + graph = self.register_edges(graph, adjmat.col, adjmat.row) + + if attrs_config is None: + return graph + + graph = self.register_attributes(graph, attrs_config) + + return graph + + +class KNNEdges(BaseEdgeBuilder): + """Computes KNN based edges and adds them to the graph. + + Attributes + ---------- + source_name : str + The name of the source nodes. + target_name : str + The name of the target nodes. + num_nearest_neighbours : int + Number of nearest neighbours. + """ + + def __init__(self, source_name: str, target_name: str, num_nearest_neighbours: int): + super().__init__(source_name, target_name) + assert isinstance(num_nearest_neighbours, int), "Number of nearest neighbours must be an integer" + assert num_nearest_neighbours > 0, "Number of nearest neighbours must be positive" + self.num_nearest_neighbours = num_nearest_neighbours + + def get_adjacency_matrix(self, source_nodes: np.ndarray, target_nodes: np.ndarray): + """Compute the adjacency matrix for the KNN method. + + Parameters + ---------- + source_nodes : np.ndarray + The source nodes. + target_nodes : np.ndarray + The target nodes. + """ + assert self.num_nearest_neighbours is not None, "number of neighbors required for knn encoder" + LOGGER.info( + "Using KNN-Edges (with %d nearest neighbours) between %s and %s.", + self.num_nearest_neighbours, + self.source_name, + self.target_name, + ) + + nearest_neighbour = NearestNeighbors(metric="haversine", n_jobs=4) + nearest_neighbour.fit(source_nodes.x.numpy()) + adj_matrix = nearest_neighbour.kneighbors_graph( + target_nodes.x.numpy(), + n_neighbors=self.num_nearest_neighbours, + mode="distance", + ).tocoo() + return adj_matrix + + +class CutOffEdges(BaseEdgeBuilder): + """Computes cut-off based edges and adds them to the graph. + + Attributes + ---------- + source_name : str + The name of the source nodes. + target_name : str + The name of the target nodes. + cutoff_factor : float + Factor to multiply the grid reference distance to get the cut-off radius. + radius : float + Cut-off radius. + """ + + def __init__(self, source_name: str, target_name: str, cutoff_factor: float): + super().__init__(source_name, target_name) + assert isinstance(cutoff_factor, (int, float)), "Cutoff factor must be a float" + assert cutoff_factor > 0, "Cutoff factor must be positive" + self.cutoff_factor = cutoff_factor + + def get_cutoff_radius(self, graph: HeteroData, mask_attr: Optional[torch.Tensor] = None): + """Compute the cut-off radius. + + The cut-off radius is computed as the product of the target nodes reference distance and the cut-off factor. + + Parameters + ---------- + graph : HeteroData + The graph. + mask_attr : torch.Tensor + The mask attribute. + + Returns + ------- + float + The cut-off radius. + """ + target_nodes = graph[self.target_name] + mask = target_nodes[mask_attr] if mask_attr is not None else None + target_grid_reference_distance = get_grid_reference_distance(target_nodes.x, mask) + radius = target_grid_reference_distance * self.cutoff_factor + return radius + + def prepare_node_data(self, graph: HeteroData) -> tuple[NodeStorage, NodeStorage]: + """Prepare nodes information.""" + self.radius = self.get_cutoff_radius(graph) + return super().prepare_node_data(graph) + + def get_adjacency_matrix(self, source_nodes: NodeStorage, target_nodes: NodeStorage): + """Get the adjacency matrix for the cut-off method. + + Parameters + ---------- + source_nodes : NodeStorage + The source nodes. + target_nodes : NodeStorage + The target nodes. + """ + LOGGER.info( + "Using CutOff-Edges (with radius = %.1f km) between %s and %s.", + self.radius * EARTH_RADIUS, + self.source_name, + self.target_name, + ) + + nearest_neighbour = NearestNeighbors(metric="haversine", n_jobs=4) + nearest_neighbour.fit(source_nodes.x) + adj_matrix = nearest_neighbour.radius_neighbors_graph(target_nodes.x, radius=self.radius).tocoo() + return adj_matrix diff --git a/src/anemoi/graphs/edges/directional.py b/src/anemoi/graphs/edges/directional.py new file mode 100644 index 0000000..9c7cdea --- /dev/null +++ b/src/anemoi/graphs/edges/directional.py @@ -0,0 +1,85 @@ +from typing import Optional + +import numpy as np +from scipy.spatial.transform import Rotation + +from anemoi.graphs.generate.transforms import direction_vec +from anemoi.graphs.generate.transforms import latlon_rad_to_cartesian + + +def get_rotation_from_unit_vecs(points: np.ndarray, reference: np.ndarray) -> Rotation: + """Compute rotation matrix of a set of points with respect to a reference vector. + + Parameters + ---------- + points : np.ndarray of shape (num_points, 3) + The points to compute the direction vector. + reference : np.ndarray of shape (3, ) + The reference vector. + + Returns + ------- + Rotation + The rotation matrix that aligns the points with the reference vector. + """ + assert points.shape[1] == 3, "Points must be in 3D" + v_unit = direction_vec(points, reference) + theta = np.arccos(np.dot(points, reference)) + return Rotation.from_rotvec(np.transpose(v_unit * theta)) + + +def compute_directions(loc1: np.ndarray, loc2: np.ndarray, pole_vec: Optional[np.ndarray] = None) -> np.ndarray: + """Compute the direction of the edge joining the nodes considered. + + Parameters + ---------- + loc1 : np.ndarray of shape (2, num_points) + Location of the head nodes. + loc2 : np.ndarray of shape (2, num_points) + Location of the tail nodes. + pole_vec : np.ndarray, optional + The pole vector to rotate the points to. Defaults to the north pole. + + Returns + ------- + np.ndarray of shape (3, num_points) + The direction of the edge after rotating the north pole. + """ + if pole_vec is None: + pole_vec = np.array([0, 0, 1]) + + # all will be rotated relative to destination node + loc1_xyz = latlon_rad_to_cartesian(loc1, 1.0) + loc2_xyz = latlon_rad_to_cartesian(loc2, 1.0) + r = get_rotation_from_unit_vecs(loc2_xyz, pole_vec) + direction = direction_vec(r.apply(loc1_xyz), pole_vec) + return direction / np.sqrt(np.power(direction, 2).sum(axis=0)) + + +def directional_edge_features( + loc1: np.ndarray, loc2: np.ndarray, relative_to_rotated_target: bool = True +) -> np.ndarray: + """Compute features of the edge joining the nodes considered. + + It computes the direction of the edge after rotating the north pole. + + Parameters + ---------- + loc1 : np.ndarray of shpae (2, num_points) + Location of the head node. + loc2 : np.ndarray of shape (2, num_points) + Location of the tail node. + relative_to_rotated_target : bool, optional + Whether to rotate the north pole to the target node. Defaults to True. + + Returns + ------- + np.ndarray of shape of (2, num_points) + Direction of the edge after rotation the north pole. + """ + if relative_to_rotated_target: + rotation = compute_directions(loc1, loc2) + assert np.allclose(rotation[2], 0), "Rotation should be aligned with the north pole" + return rotation[:2] + + return loc2 - loc1 diff --git a/src/anemoi/graphs/generate/__init__.py b/src/anemoi/graphs/generate/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/anemoi/graphs/generate/transforms.py b/src/anemoi/graphs/generate/transforms.py new file mode 100644 index 0000000..99e838e --- /dev/null +++ b/src/anemoi/graphs/generate/transforms.py @@ -0,0 +1,95 @@ +import numpy as np + + +def cartesian_to_latlon_degrees(xyz: np.ndarray) -> np.ndarray: + """3D to lat-lon (in degrees) conversion. + + Convert 3D coordinates of points to the (lat, lon) on the sphere containing + them. + + Parameters + ---------- + xyz : np.ndarray + The 3D coordinates of points. + + Returns + ------- + np.ndarray + A 2D array of lat-lon coordinates of shape (N, 2). + """ + lat = np.arcsin(xyz[..., 2] / (xyz**2).sum(axis=1)) * 180.0 / np.pi + lon = np.arctan2(xyz[..., 1], xyz[..., 0]) * 180.0 / np.pi + return np.array((lat, lon), dtype=np.float32).transpose() + + +def cartesian_to_latlon_rad(xyz: np.ndarray) -> np.ndarray: + """3D to lat-lon (in radians) conversion. + + Convert 3D coordinates of points to its coordinates on the sphere containing + them. + + Parameters + ---------- + xyz : np.ndarray + The 3D coordinates of points. + + Returns + ------- + np.ndarray + A 2D array of the coordinates of shape (N, 2) in radians. + """ + lat = np.arcsin(xyz[..., 2] / (xyz**2).sum(axis=1)) + lon = np.arctan2(xyz[..., 1], xyz[..., 0]) + return np.array((lat, lon), dtype=np.float32).transpose() + + +def latlon_rad_to_cartesian(loc: tuple[np.ndarray, np.ndarray], radius: float = 1) -> np.ndarray: + """Convert planar coordinates to 3D coordinates in a sphere. + + Parameters + ---------- + loc : np.ndarray + The 2D coordinates of the points, in radians. + radius : float, optional + The radius of the sphere containing los points. Defaults to the unit sphere. + + Returns + ------- + np.array of shape (3, num_points) + 3D coordinates of the points in the sphere. + """ + latr, lonr = loc[0], loc[1] + x = radius * np.cos(latr) * np.cos(lonr) + y = radius * np.cos(latr) * np.sin(lonr) + z = radius * np.sin(latr) + return np.array((x, y, z)).T + + +def direction_vec(points: np.ndarray, reference: np.ndarray, epsilon: float = 10e-11) -> np.ndarray: + """Direction vector computation. + + Compute the direction vector of a set of points with respect to a reference + vector. + + Parameters + ---------- + points : np.array of shape (num_points, 3) + The points to compute the direction vector. + reference : np.array of shape (3, ) + The reference vector. + epsilon : float, optional + The value to add to the first vector to avoid division by zero. Defaults to 10e-11. + + Returns + ------- + np.array of shape (3, num_points) + The direction vector of the cross product of the two vectors. + """ + v = np.cross(points, reference) + vnorm1 = np.power(v, 2).sum(axis=-1) + redo_idx = np.where(vnorm1 < epsilon)[0] + if len(redo_idx) > 0: + points[redo_idx] += epsilon + v = np.cross(points, reference) + vnorm1 = np.power(v, 2).sum(axis=-1) + return v.T / np.sqrt(vnorm1) diff --git a/src/anemoi/graphs/nodes/__init__.py b/src/anemoi/graphs/nodes/__init__.py new file mode 100644 index 0000000..737f27f --- /dev/null +++ b/src/anemoi/graphs/nodes/__init__.py @@ -0,0 +1,4 @@ +from .builder import NPZFileNodes +from .builder import ZarrDatasetNodes + +__all__ = ["ZarrDatasetNodes", "NPZFileNodes"] diff --git a/src/anemoi/graphs/nodes/attributes.py b/src/anemoi/graphs/nodes/attributes.py new file mode 100644 index 0000000..b1942b7 --- /dev/null +++ b/src/anemoi/graphs/nodes/attributes.py @@ -0,0 +1,115 @@ +import logging +from abc import ABC +from abc import abstractmethod +from dataclasses import dataclass +from typing import Optional + +import numpy as np +import torch +from scipy.spatial import SphericalVoronoi +from torch_geometric.data.storage import NodeStorage + +from anemoi.graphs.generate.transforms import latlon_rad_to_cartesian +from anemoi.graphs.normalizer import NormalizerMixin + +LOGGER = logging.getLogger(__name__) + + +@dataclass +class BaseWeights(ABC, NormalizerMixin): + """Base class for the weights of the nodes.""" + + norm: Optional[str] = None + + @abstractmethod + def get_raw_values(self, nodes: NodeStorage, *args, **kwargs): ... + + def post_process(self, values: np.ndarray) -> torch.Tensor: + """Post-process the values.""" + if values.ndim == 1: + values = values[:, np.newaxis] + + return torch.tensor(values) + + def compute(self, nodes: NodeStorage, *args, **kwargs) -> torch.Tensor: + """Get the node weights. + + Returns + ------- + torch.Tensor + Weights associated to the nodes. + """ + weights = self.get_raw_values(nodes, *args, **kwargs) + norm_weights = self.normalize(weights) + return self.post_process(norm_weights) + + +class UniformWeights(BaseWeights): + """Implements a uniform weight for the nodes.""" + + def get_raw_values(self, nodes: NodeStorage, *args, **kwargs) -> np.ndarray: + """Compute the weights. + + Parameters + ---------- + nodes : NodeStorage + Nodes of the graph. + + Returns + ------- + np.ndarray + Weights. + """ + return np.ones(nodes.num_nodes) + + +@dataclass +class AreaWeights(BaseWeights): + """Implements the area of the nodes as the weights. + + Attributes + ---------- + norm : str + Normalization of the weights. + radius : float + Radius of the sphere. + centre : np.ndarray + Centre of the sphere. + + Methods + ------- + get_raw_values(nodes, *args, **kwargs) + Compute the area associated to each node. + compute(nodes, *args, **kwargs) + Compute the area attributes for each node. + """ + + norm: Optional[str] = "unit-max" + radius: float = 1.0 + centre: np.ndarray = np.array([0, 0, 0]) + + def get_raw_values(self, nodes: NodeStorage, *args, **kwargs) -> np.ndarray: + """Compute the area associated to each node. + + It uses Voronoi diagrams to compute the area of each node. + + Parameters + ---------- + nodes : NodeStorage + Nodes of the graph. + + Returns + ------- + np.ndarray + Weights. + """ + latitudes, longitudes = nodes.x[:, 0], nodes.x[:, 1] + points = latlon_rad_to_cartesian((latitudes, longitudes)) + sv = SphericalVoronoi(points, self.radius, self.centre) + area_weights = sv.calculate_areas() + LOGGER.debug( + "There are %d of weights, which (unscaled) add up a total weight of %.2f.", + len(area_weights), + np.array(area_weights).sum(), + ) + return area_weights diff --git a/src/anemoi/graphs/nodes/builder.py b/src/anemoi/graphs/nodes/builder.py new file mode 100644 index 0000000..11e99f6 --- /dev/null +++ b/src/anemoi/graphs/nodes/builder.py @@ -0,0 +1,168 @@ +import logging +from abc import ABC +from abc import abstractmethod +from pathlib import Path +from typing import Optional + +import numpy as np +import torch +from anemoi.datasets import open_dataset +from anemoi.utils.config import DotDict +from hydra.utils import instantiate +from torch_geometric.data import HeteroData + +LOGGER = logging.getLogger(__name__) + + +class BaseNodeBuilder(ABC): + """Base class for node builders.""" + + def register_nodes(self, graph: HeteroData, name: str) -> None: + """Register nodes in the graph. + + Parameters + ---------- + graph : HeteroData + The graph to register the nodes. + name : str + The name of the nodes. + """ + graph[name].x = self.get_coordinates() + graph[name].node_type = type(self).__name__ + return graph + + def register_attributes(self, graph: HeteroData, name: str, config: Optional[DotDict] = None) -> HeteroData: + """Register attributes in the nodes of the graph specified. + + Parameters + ---------- + graph : HeteroData + The graph to register the attributes. + name : str + The name of the nodes. + config : DotDict + The configuration of the attributes. + + Returns + ------- + HeteroData + The graph with the registered attributes. + """ + if config is None: + return graph + + for attr_name, attr_config in config.items(): + graph[name][attr_name] = instantiate(attr_config).compute(graph[name]) + return graph + + @abstractmethod + def get_coordinates(self) -> torch.Tensor: ... + + def reshape_coords(self, latitudes: np.ndarray, longitudes: np.ndarray) -> torch.Tensor: + """Reshape latitude and longitude coordinates. + + Parameters + ---------- + latitudes : np.ndarray of shape (N, ) + Latitude coordinates, in degrees. + longitudes : np.ndarray of shape (N, ) + Longitude coordinates, in degrees. + + Returns + ------- + torch.Tensor of shape (N, 2) + A 2D tensor with the coordinates, in radians. + """ + coords = np.stack([latitudes, longitudes], axis=-1).reshape((-1, 2)) + coords = np.deg2rad(coords) + return torch.tensor(coords, dtype=torch.float32) + + def update_graph(self, graph: HeteroData, name: str, attr_config: Optional[DotDict] = None) -> HeteroData: + """Update the graph with new nodes. + + Parameters + ---------- + graph : HeteroData + Input graph. + name : str + The name of the nodes. + attr_config : DotDict + The configuration of the attributes. + + Returns + ------- + HeteroData + The graph with new nodes included. + """ + graph = self.register_nodes(graph, name) + + if attr_config is None: + return graph + + graph = self.register_attributes(graph, name, attr_config) + return graph + + +class ZarrDatasetNodes(BaseNodeBuilder): + """Nodes from Zarr dataset. + + Attributes + ---------- + ds : zarr.core.Array + The dataset. + """ + + def __init__(self, dataset: DotDict) -> None: + LOGGER.info("Reading the dataset from %s.", dataset) + self.ds = open_dataset(dataset) + + def get_coordinates(self) -> torch.Tensor: + """Get the coordinates of the nodes. + + Returns + ------- + torch.Tensor of shape (N, 2) + Coordinates of the nodes. + """ + return self.reshape_coords(self.ds.latitudes, self.ds.longitudes) + + +class NPZFileNodes(BaseNodeBuilder): + """Nodes from NPZ defined grids. + + Attributes + ---------- + resolution : str + The resolution of the grid. + grid_definition_path : str + Path to the folder containing the grid definition files. + grid_definition : dict[str, np.ndarray] + The grid definition. + """ + + def __init__(self, resolution: str, grid_definition_path: str) -> None: + """Initialize the NPZFileNodes builder. + + The builder suppose the grids are stored in files with the name `grid-{resolution}.npz`. + + Parameters + ---------- + resolution : str + The resolution of the grid. + grid_definition_path : str + Path to the folder containing the grid definition files. + """ + self.resolution = resolution + self.grid_definition_path = grid_definition_path + self.grid_definition = np.load(Path(self.grid_definition_path) / f"grid-{self.resolution}.npz") + + def get_coordinates(self) -> torch.Tensor: + """Get the coordinates of the nodes. + + Returns + ------- + torch.Tensor of shape (N, 2) + Coordinates of the nodes. + """ + coords = self.reshape_coords(self.grid_definition["latitudes"], self.grid_definition["longitudes"]) + return coords diff --git a/src/anemoi/graphs/normalizer.py b/src/anemoi/graphs/normalizer.py new file mode 100644 index 0000000..9c261e5 --- /dev/null +++ b/src/anemoi/graphs/normalizer.py @@ -0,0 +1,44 @@ +import logging + +import numpy as np + +LOGGER = logging.getLogger(__name__) + + +class NormalizerMixin: + """Mixin class for normalizing attributes.""" + + def normalize(self, values: np.ndarray) -> np.ndarray: + """Normalize the given values. + + It supports different normalization methods: None, 'l1', + 'l2', 'unit-max' and 'unit-std'. + + Parameters + ---------- + values : np.ndarray + Values to normalize. + + Returns + ------- + np.ndarray + Normalized values. + """ + if self.norm is None: + LOGGER.debug("Node weights are not normalized.") + return values + if self.norm == "l1": + return values / np.sum(values) + if self.norm == "l2": + return values / np.linalg.norm(values) + if self.norm == "unit-max": + return values / np.amax(values) + if self.norm == "unit-std": + std = np.std(values) + if std == 0: + LOGGER.warning(f"Std. dev. of the {self.__class__.__name__} is 0. Cannot normalize.") + return values + return values / std + raise ValueError( + f"Weight normalization \"{self.norm}\" is not valid. Options are: 'l1', 'l2', 'unit-max' or 'unit-std'." + ) diff --git a/src/anemoi/graphs/utils.py b/src/anemoi/graphs/utils.py new file mode 100644 index 0000000..8999bc6 --- /dev/null +++ b/src/anemoi/graphs/utils.py @@ -0,0 +1,133 @@ +from typing import Optional + +import numpy as np +import torch +from sklearn.neighbors import NearestNeighbors + + +def get_nearest_neighbour(coords_rad: torch.Tensor, mask: Optional[torch.Tensor] = None) -> NearestNeighbors: + """Get NearestNeighbour object fitted to coordinates. + + Parameters + ---------- + coords_rad : torch.Tensor + corrdinates in radians + mask : Optional[torch.Tensor], optional + mask to remove nodes, by default None + + Returns + ------- + NearestNeighbors + fitted NearestNeighbour object + """ + assert mask is None or mask.shape == ( + coords_rad.shape[0], + 1, + ), "Mask must have the same shape as the number of nodes." + + nearest_neighbour = NearestNeighbors(metric="haversine", n_jobs=4) + + nearest_neighbour.fit(coords_rad) + + return nearest_neighbour + + +def get_grid_reference_distance(coords_rad: torch.Tensor, mask: Optional[torch.Tensor] = None) -> float: + """Get the reference distance of the grid. + + It is the maximum distance of a node in the mesh with respect to its nearest neighbour. + + Parameters + ---------- + coords_rad : torch.Tensor + corrdinates in radians + mask : Optional[torch.Tensor], optional + mask to remove nodes, by default None + + Returns + ------- + float + The reference distance of the grid. + """ + nearest_neighbours = get_nearest_neighbour(coords_rad, mask) + dists, _ = nearest_neighbours.kneighbors(coords_rad, n_neighbors=2, return_distance=True) + return dists[dists > 0].max() + + +def add_margin(lats: np.ndarray, lons: np.ndarray, margin: float) -> tuple[np.ndarray, np.ndarray]: + """Add a margin to the convex hull of the points considered. + + For each point (lat, lon) add 8 points around it, each at a distance of `margin` from the original point. + + Arguments + --------- + lats : np.ndarray + Latitudes of the points considered. + lons : np.ndarray + Longitudes of the points considered. + margin : float + The margin to add to the convex hull. + + Returns + ------- + latitudes : np.ndarray + Latitudes of the points considered, including the margin. + longitudes : np.ndarray + Longitudes of the points considered, including the margin. + """ + assert margin >= 0, "Margin must be non-negative" + if margin == 0: + return lats, lons + + latitudes, longitudes = [], [] + for lat_sign in [-1, 0, 1]: + for lon_sign in [-1, 0, 1]: + latitudes.append(lats + lat_sign * margin) + longitudes.append(lons + lon_sign * margin) + + return np.concatenate(latitudes), np.concatenate(longitudes) + + +def get_index_in_outer_join(vector: torch.Tensor, tensor: torch.Tensor) -> int: + """Index position of vector. + + Get the index position of a vector in a matrix. + + Parameters + ---------- + vector : torch.Tensor of shape (N, ) + Vector to get its position in the matrix. + tensor : torch.Tensor of shape (M, N,) + Tensor in which the position is searched. + + Returns + ------- + int + Index position of `vector` in `tensor`. -1 if `vector` is not in `tensor`. + """ + mask = torch.all(tensor == vector, axis=1) + if mask.any(): + return int(torch.where(mask)[0]) + return -1 + + +def haversine_distance(source_coords: np.ndarray, target_coords: np.ndarray) -> np.ndarray: + """Haversine distance. + + Parameters + ---------- + source_coords : np.ndarray of shape (N, 2) + Source coordinates in radians. + target_coords : np.ndarray of shape (N, 2) + Destination coordinates in radians. + + Returns + ------- + np.ndarray of shape (N,) + Haversine distance between source and destination coordinates. + """ + dlat = target_coords[:, 0] - source_coords[:, 0] + dlon = target_coords[:, 1] - source_coords[:, 1] + a = np.sin(dlat / 2) ** 2 + np.cos(source_coords[:, 0]) * np.cos(target_coords[:, 0]) * np.sin(dlon / 2) ** 2 + c = 2 * np.arctan2(np.sqrt(a), np.sqrt(1 - a)) + return c diff --git a/tests/conftest.py b/tests/conftest.py new file mode 100644 index 0000000..290165c --- /dev/null +++ b/tests/conftest.py @@ -0,0 +1,92 @@ +import numpy as np +import pytest +import torch +import yaml +from torch_geometric.data import HeteroData + +lats = [-0.15, 0, 0.15] +lons = [0, 0.25, 0.5, 0.75] + + +class MockZarrDataset: + """Mock Zarr dataset with latitudes and longitudes attributes.""" + + def __init__(self, latitudes, longitudes): + self.latitudes = latitudes + self.longitudes = longitudes + self.num_nodes = len(latitudes) + + +@pytest.fixture +def mock_zarr_dataset() -> MockZarrDataset: + """Mock zarr dataset with nodes.""" + coords = 2 * torch.pi * np.array([[lat, lon] for lat in lats for lon in lons]) + return MockZarrDataset(latitudes=coords[:, 0], longitudes=coords[:, 1]) + + +@pytest.fixture +def mock_grids_path(tmp_path) -> tuple[str, int]: + """Mock grid_definition_path with files for 3 resolutions.""" + num_nodes = len(lats) * len(lons) + for resolution in ["o16", "o48", "5km5"]: + file_path = tmp_path / f"grid-{resolution}.npz" + np.savez(file_path, latitudes=np.random.rand(num_nodes), longitudes=np.random.rand(num_nodes)) + return str(tmp_path), num_nodes + + +@pytest.fixture +def graph_with_nodes() -> HeteroData: + """Graph with 12 nodes over the globe, stored in \"test_nodes\".""" + coords = np.array([[lat, lon] for lat in lats for lon in lons]) + graph = HeteroData() + graph["test_nodes"].x = 2 * torch.pi * torch.tensor(coords) + return graph + + +@pytest.fixture +def graph_nodes_and_edges() -> HeteroData: + """Graph with 1 set of nodes and edges.""" + coords = np.array([[lat, lon] for lat in lats for lon in lons]) + graph = HeteroData() + graph["test_nodes"].x = 2 * torch.pi * torch.tensor(coords) + graph[("test_nodes", "to", "test_nodes")].edge_index = torch.tensor([[0, 1], [1, 2], [2, 3], [3, 0]]) + return graph + + +@pytest.fixture +def config_file(tmp_path) -> tuple[str, str]: + """Mock grid_definition_path with files for 3 resolutions.""" + cfg = { + "nodes": { + "test_nodes": { + "node_builder": { + "_target_": "anemoi.graphs.nodes.NPZFileNodes", + "grid_definition_path": str(tmp_path), + "resolution": "o16", + }, + } + }, + "edges": [ + { + "nodes": {"source_name": "test_nodes", "target_name": "test_nodes"}, + "edge_builder": { + "_target_": "anemoi.graphs.edges.KNNEdges", + "num_nearest_neighbours": 3, + }, + "attributes": { + "dist_norm": { + "_target_": "anemoi.graphs.edges.attributes.EdgeLength", + "norm": "l1", + "invert": True, + }, + "edge_dirs": {"_target_": "anemoi.graphs.edges.attributes.EdgeDirection"}, + }, + }, + ], + } + file_name = "config.yaml" + + with (tmp_path / file_name).open("w") as file: + yaml.dump(cfg, file) + + return tmp_path, file_name diff --git a/tests/edges/test_cutoff.py b/tests/edges/test_cutoff.py new file mode 100644 index 0000000..838134c --- /dev/null +++ b/tests/edges/test_cutoff.py @@ -0,0 +1,22 @@ +import pytest + +from anemoi.graphs.edges import CutOffEdges + + +def test_init(): + """Test CutOffEdges initialization.""" + CutOffEdges("test_nodes1", "test_nodes2", 0.5) + + +@pytest.mark.parametrize("cutoff_factor", [-0.5, "hello", None]) +def test_fail_init(cutoff_factor: str): + """Test CutOffEdges initialization with invalid cutoff.""" + with pytest.raises(AssertionError): + CutOffEdges("test_nodes1", "test_nodes2", cutoff_factor) + + +def test_cutoff(graph_with_nodes): + """Test CutOffEdges.""" + builder = CutOffEdges("test_nodes", "test_nodes", 0.5) + graph = builder.update_graph(graph_with_nodes) + assert ("test_nodes", "to", "test_nodes") in graph.edge_types diff --git a/tests/edges/test_edge_attributes.py b/tests/edges/test_edge_attributes.py new file mode 100644 index 0000000..b0bbede --- /dev/null +++ b/tests/edges/test_edge_attributes.py @@ -0,0 +1,29 @@ +import pytest +import torch + +from anemoi.graphs.edges.attributes import EdgeDirection +from anemoi.graphs.edges.attributes import EdgeLength + + +@pytest.mark.parametrize("norm", ["l1", "l2", "unit-max", "unit-std"]) +@pytest.mark.parametrize("luse_rotated_features", [True, False]) +def test_directional_features(graph_nodes_and_edges, norm, luse_rotated_features: bool): + """Test EdgeDirection compute method.""" + edge_attr_builder = EdgeDirection(norm=norm, luse_rotated_features=luse_rotated_features) + edge_attr = edge_attr_builder.compute(graph_nodes_and_edges, "test_nodes", "test_nodes") + assert isinstance(edge_attr, torch.Tensor) + + +@pytest.mark.parametrize("norm", ["l1", "l2", "unit-max", "unit-std"]) +def test_edge_lengths(graph_nodes_and_edges, norm): + """Test EdgeLength compute method.""" + edge_attr_builder = EdgeLength(norm=norm) + edge_attr = edge_attr_builder.compute(graph_nodes_and_edges, "test_nodes", "test_nodes") + assert isinstance(edge_attr, torch.Tensor) + + +@pytest.mark.parametrize("attribute_builder", [EdgeDirection(), EdgeLength()]) +def test_fail_edge_features(attribute_builder, graph_nodes_and_edges): + """Test EdgeDirection compute method.""" + with pytest.raises(AssertionError): + attribute_builder.compute(graph_nodes_and_edges, "test_nodes", "unknown_nodes") diff --git a/tests/edges/test_knn.py b/tests/edges/test_knn.py new file mode 100644 index 0000000..9f6cae9 --- /dev/null +++ b/tests/edges/test_knn.py @@ -0,0 +1,22 @@ +import pytest + +from anemoi.graphs.edges import KNNEdges + + +def test_init(): + """Test KNNEdges initialization.""" + KNNEdges("test_nodes1", "test_nodes2", 3) + + +@pytest.mark.parametrize("num_nearest_neighbours", [-1, 2.6, "hello", None]) +def test_fail_init(num_nearest_neighbours: str): + """Test KNNEdges initialization with invalid number of nearest neighbours.""" + with pytest.raises(AssertionError): + KNNEdges("test_nodes1", "test_nodes2", num_nearest_neighbours) + + +def test_knn(graph_with_nodes): + """Test KNNEdges.""" + builder = KNNEdges("test_nodes", "test_nodes", 3) + graph = builder.update_graph(graph_with_nodes) + assert ("test_nodes", "to", "test_nodes") in graph.edge_types diff --git a/tests/nodes/test_node_attributes.py b/tests/nodes/test_node_attributes.py new file mode 100644 index 0000000..3d7e5be --- /dev/null +++ b/tests/nodes/test_node_attributes.py @@ -0,0 +1,42 @@ +import pytest +import torch +from torch_geometric.data import HeteroData + +from anemoi.graphs.nodes.attributes import AreaWeights +from anemoi.graphs.nodes.attributes import UniformWeights + + +@pytest.mark.parametrize("norm", [None, "l1", "l2", "unit-max", "unit-std"]) +def test_uniform_weights(graph_with_nodes: HeteroData, norm: str): + """Test NPZNodes register correctly the weights.""" + node_attr_builder = UniformWeights(norm=norm) + weights = node_attr_builder.compute(graph_with_nodes["test_nodes"]) + + assert weights is not None + assert isinstance(weights, torch.Tensor) + assert weights.shape[0] == graph_with_nodes["test_nodes"].x.shape[0] + + +@pytest.mark.parametrize("norm", ["l3", "invalide"]) +def test_uniform_weights_fail(graph_with_nodes: HeteroData, norm: str): + """Test NPZNodes register correctly the weights.""" + with pytest.raises(ValueError): + node_attr_builder = UniformWeights(norm=norm) + node_attr_builder.compute(graph_with_nodes["test_nodes"]) + + +def test_area_weights(graph_with_nodes: HeteroData): + """Test NPZNodes register correctly the weights.""" + node_attr_builder = AreaWeights() + weights = node_attr_builder.compute(graph_with_nodes["test_nodes"]) + + assert weights is not None + assert isinstance(weights, torch.Tensor) + assert weights.shape[0] == graph_with_nodes["test_nodes"].x.shape[0] + + +@pytest.mark.parametrize("radius", [-1.0, "hello", None]) +def test_area_weights_fail(graph_with_nodes: HeteroData, radius: float): + with pytest.raises(ValueError): + node_attr_builder = AreaWeights(radius=radius) + node_attr_builder.compute(graph_with_nodes["test_nodes"]) diff --git a/tests/nodes/test_npz.py b/tests/nodes/test_npz.py new file mode 100644 index 0000000..fc4cf8c --- /dev/null +++ b/tests/nodes/test_npz.py @@ -0,0 +1,58 @@ +import pytest +import torch +from torch_geometric.data import HeteroData + +from anemoi.graphs.nodes.attributes import AreaWeights +from anemoi.graphs.nodes.attributes import UniformWeights +from anemoi.graphs.nodes.builder import NPZFileNodes + + +@pytest.mark.parametrize("resolution", ["o16", "o48", "5km5"]) +def test_init(mock_grids_path: tuple[str, int], resolution: str): + """Test NPZNodes initialization.""" + grid_definition_path, _ = mock_grids_path + node_builder = NPZFileNodes(resolution, grid_definition_path=grid_definition_path) + assert isinstance(node_builder, NPZFileNodes) + + +@pytest.mark.parametrize("resolution", ["o17", 13, "ajsnb", None]) +def test_fail_init_wrong_resolution(mock_grids_path: tuple[str, int], resolution: str): + """Test NPZNodes initialization with invalid resolution.""" + grid_definition_path, _ = mock_grids_path + with pytest.raises(FileNotFoundError): + NPZFileNodes(resolution, grid_definition_path=grid_definition_path) + + +def test_fail_init_wrong_path(): + """Test NPZNodes initialization with invalid path.""" + with pytest.raises(FileNotFoundError): + NPZFileNodes("o16", "invalid_path") + + +@pytest.mark.parametrize("resolution", ["o16", "o48", "5km5"]) +def test_register_nodes(mock_grids_path: str, resolution: str): + """Test NPZNodes register correctly the nodes.""" + graph = HeteroData() + grid_definition_path, num_nodes = mock_grids_path + node_builder = NPZFileNodes(resolution, grid_definition_path=grid_definition_path) + + graph = node_builder.register_nodes(graph, "test_nodes") + + assert graph["test_nodes"].x is not None + assert isinstance(graph["test_nodes"].x, torch.Tensor) + assert graph["test_nodes"].x.shape == (num_nodes, 2) + assert graph["test_nodes"].node_type == "NPZFileNodes" + + +@pytest.mark.parametrize("attr_class", [UniformWeights, AreaWeights]) +def test_register_attributes(graph_with_nodes: HeteroData, mock_grids_path: tuple[str, int], attr_class): + """Test NPZNodes register correctly the weights.""" + grid_definition_path, _ = mock_grids_path + node_builder = NPZFileNodes("o16", grid_definition_path=grid_definition_path) + config = {"test_attr": {"_target_": f"anemoi.graphs.nodes.attributes.{attr_class.__name__}"}} + + graph = node_builder.register_attributes(graph_with_nodes, "test_nodes", config) + + assert graph["test_nodes"]["test_attr"] is not None + assert isinstance(graph["test_nodes"]["test_attr"], torch.Tensor) + assert graph["test_nodes"]["test_attr"].shape[0] == graph["test_nodes"].x.shape[0] diff --git a/tests/nodes/test_zarr.py b/tests/nodes/test_zarr.py new file mode 100644 index 0000000..0e91ece --- /dev/null +++ b/tests/nodes/test_zarr.py @@ -0,0 +1,50 @@ +import pytest +import torch +import zarr +from torch_geometric.data import HeteroData + +from anemoi.graphs.nodes import builder +from anemoi.graphs.nodes.attributes import AreaWeights +from anemoi.graphs.nodes.attributes import UniformWeights + + +def test_init(mocker, mock_zarr_dataset): + """Test ZarrNodes initialization.""" + mocker.patch.object(builder, "open_dataset", return_value=mock_zarr_dataset) + node_builder = builder.ZarrDatasetNodes("dataset.zarr") + assert isinstance(node_builder, builder.BaseNodeBuilder) + assert isinstance(node_builder, builder.ZarrDatasetNodes) + + +def test_fail_init(): + """Test ZarrNodes initialization with invalid resolution.""" + with pytest.raises(zarr.errors.PathNotFoundError): + builder.ZarrDatasetNodes("invalid_path.zarr") + + +def test_register_nodes(mocker, mock_zarr_dataset): + """Test ZarrNodes register correctly the nodes.""" + mocker.patch.object(builder, "open_dataset", return_value=mock_zarr_dataset) + node_builder = builder.ZarrDatasetNodes("dataset.zarr") + graph = HeteroData() + + graph = node_builder.register_nodes(graph, "test_nodes") + + assert graph["test_nodes"].x is not None + assert isinstance(graph["test_nodes"].x, torch.Tensor) + assert graph["test_nodes"].x.shape == (node_builder.ds.num_nodes, 2) + assert graph["test_nodes"].node_type == "ZarrDatasetNodes" + + +@pytest.mark.parametrize("attr_class", [UniformWeights, AreaWeights]) +def test_register_attributes(mocker, graph_with_nodes: HeteroData, attr_class): + """Test ZarrNodes register correctly the weights.""" + mocker.patch.object(builder, "open_dataset", return_value=None) + node_builder = builder.ZarrDatasetNodes("dataset.zarr") + config = {"test_attr": {"_target_": f"anemoi.graphs.nodes.attributes.{attr_class.__name__}"}} + + graph = node_builder.register_attributes(graph_with_nodes, "test_nodes", config) + + assert graph["test_nodes"]["test_attr"] is not None + assert isinstance(graph["test_nodes"]["test_attr"], torch.Tensor) + assert graph["test_nodes"]["test_attr"].shape[0] == graph["test_nodes"].x.shape[0] diff --git a/tests/test_graphs.py b/tests/test_graphs.py index 846ee89..50731f7 100644 --- a/tests/test_graphs.py +++ b/tests/test_graphs.py @@ -6,9 +6,31 @@ # nor does it submit to any jurisdiction. -def test_graphs(): - pass +from pathlib import Path +import torch +from torch_geometric.data import HeteroData -if __name__ == "__main__": - test_graphs() +from anemoi.graphs import create + + +def test_graphs(config_file: tuple[Path, str], mock_grids_path: tuple[str, int]): + """Test GraphCreator workflow.""" + tmp_path, config_name = config_file + graph_path = tmp_path / "graph.pt" + config_path = tmp_path / config_name + + create.GraphCreator(graph_path, config_path).create() + + graph = torch.load(graph_path) + assert isinstance(graph, HeteroData) + assert "test_nodes" in graph.node_types + assert ("test_nodes", "to", "test_nodes") in graph.edge_types + + for nodes in graph.node_stores: + for node_attr in nodes.node_attrs(): + assert isinstance(nodes[node_attr], torch.Tensor) + + for edges in graph.edge_stores: + for edge_attr in edges.edge_attrs(): + assert isinstance(edges[edge_attr], torch.Tensor) diff --git a/tests/test_normalizer.py b/tests/test_normalizer.py new file mode 100644 index 0000000..c63acce --- /dev/null +++ b/tests/test_normalizer.py @@ -0,0 +1,55 @@ +import numpy as np +import pytest + +from anemoi.graphs.normalizer import NormalizerMixin + + +@pytest.mark.parametrize("norm", ["l1", "l2", "unit-max", "unit-std"]) +def test_normalizer(norm: str): + """Test NormalizerMixin normalize method.""" + + class Normalizer(NormalizerMixin): + def __init__(self, norm): + self.norm = norm + + def __call__(self, data): + return self.normalize(data) + + normalizer = Normalizer(norm=norm) + data = np.random.rand(10, 5) + normalized_data = normalizer(data) + assert isinstance(normalized_data, np.ndarray) + assert normalized_data.shape == data.shape + + +@pytest.mark.parametrize("norm", ["l3", "invalid"]) +def test_normalizer_wrong_norm(norm: str): + """Test NormalizerMixin normalize method.""" + + class Normalizer(NormalizerMixin): + def __init__(self, norm: str): + self.norm = norm + + def __call__(self, data): + return self.normalize(data) + + with pytest.raises(ValueError): + normalizer = Normalizer(norm=norm) + data = np.random.rand(10, 5) + normalizer(data) + + +def test_normalizer_wrong_inheritance(): + """Test NormalizerMixin normalize method.""" + + class Normalizer(NormalizerMixin): + def __init__(self, attr): + self.attr = attr + + def __call__(self, data): + return self.normalize(data) + + with pytest.raises(AttributeError): + normalizer = Normalizer(attr="attr_name") + data = np.random.rand(10, 5) + normalizer(data) From a973c2d039f1fbe5a9655f0ffb54d7e673fd0365 Mon Sep 17 00:00:00 2001 From: Mario Santa Cruz Date: Sat, 6 Jul 2024 11:22:15 +0000 Subject: [PATCH 055/156] fix: attributes as torch.float32 --- src/anemoi/graphs/edges/attributes.py | 5 +++-- src/anemoi/graphs/nodes/attributes.py | 7 ++++--- src/anemoi/graphs/normalizer.py | 8 ++++---- 3 files changed, 11 insertions(+), 9 deletions(-) diff --git a/src/anemoi/graphs/edges/attributes.py b/src/anemoi/graphs/edges/attributes.py index 6945867..99c9669 100644 --- a/src/anemoi/graphs/edges/attributes.py +++ b/src/anemoi/graphs/edges/attributes.py @@ -29,7 +29,9 @@ def post_process(self, values: np.ndarray) -> torch.Tensor: if values.ndim == 1: values = values[:, np.newaxis] - return torch.tensor(values) + normed_values = self.normalize(values) + + return torch.tensor(normed_values, dtype=torch.float32) def compute(self, graph: HeteroData, source_name: str, target_name: str, *args, **kwargs) -> torch.Tensor: """Compute the edge attributes.""" @@ -41,7 +43,6 @@ def compute(self, graph: HeteroData, source_name: str, target_name: str, *args, ), f"Node \"{target_name}\" not found in graph. Optional nodes are {', '.join(graph.node_types)}." values = self.get_raw_values(graph, source_name, target_name, *args, **kwargs) - normed_values = self.normalize(values) return self.post_process(normed_values) diff --git a/src/anemoi/graphs/nodes/attributes.py b/src/anemoi/graphs/nodes/attributes.py index b1942b7..0bd4d2e 100644 --- a/src/anemoi/graphs/nodes/attributes.py +++ b/src/anemoi/graphs/nodes/attributes.py @@ -29,7 +29,9 @@ def post_process(self, values: np.ndarray) -> torch.Tensor: if values.ndim == 1: values = values[:, np.newaxis] - return torch.tensor(values) + norm_values = self.normalize(values) + + return torch.tensor(norm_values, dtype=torch.float32) def compute(self, nodes: NodeStorage, *args, **kwargs) -> torch.Tensor: """Get the node weights. @@ -40,8 +42,7 @@ def compute(self, nodes: NodeStorage, *args, **kwargs) -> torch.Tensor: Weights associated to the nodes. """ weights = self.get_raw_values(nodes, *args, **kwargs) - norm_weights = self.normalize(weights) - return self.post_process(norm_weights) + return self.post_process(weights) class UniformWeights(BaseWeights): diff --git a/src/anemoi/graphs/normalizer.py b/src/anemoi/graphs/normalizer.py index 9c261e5..c625417 100644 --- a/src/anemoi/graphs/normalizer.py +++ b/src/anemoi/graphs/normalizer.py @@ -16,7 +16,7 @@ def normalize(self, values: np.ndarray) -> np.ndarray: Parameters ---------- - values : np.ndarray + values : np.ndarray of shape (N, M) Values to normalize. Returns @@ -25,7 +25,7 @@ def normalize(self, values: np.ndarray) -> np.ndarray: Normalized values. """ if self.norm is None: - LOGGER.debug("Node weights are not normalized.") + LOGGER.debug(f"{self.__class__.__name__} values are not normalized.") return values if self.norm == "l1": return values / np.sum(values) @@ -36,9 +36,9 @@ def normalize(self, values: np.ndarray) -> np.ndarray: if self.norm == "unit-std": std = np.std(values) if std == 0: - LOGGER.warning(f"Std. dev. of the {self.__class__.__name__} is 0. Cannot normalize.") + LOGGER.warning(f"Std. dev. of the {self.__class__.__name__} values is 0. Normalization is skipped.") return values return values / std raise ValueError( - f"Weight normalization \"{self.norm}\" is not valid. Options are: 'l1', 'l2', 'unit-max' or 'unit-std'." + f"Attribute normalization \"{self.norm}\" is not valid. Options are: 'l1', 'l2', 'unit-max' or 'unit-std'." ) From af111a6f406897fc46bac99735af13c53634f1bb Mon Sep 17 00:00:00 2001 From: Mario Santa Cruz Date: Sat, 6 Jul 2024 11:22:41 +0000 Subject: [PATCH 056/156] new test: attributes must be float32 --- tests/test_graphs.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/tests/test_graphs.py b/tests/test_graphs.py index 50731f7..ba2704f 100644 --- a/tests/test_graphs.py +++ b/tests/test_graphs.py @@ -30,7 +30,9 @@ def test_graphs(config_file: tuple[Path, str], mock_grids_path: tuple[str, int]) for nodes in graph.node_stores: for node_attr in nodes.node_attrs(): assert isinstance(nodes[node_attr], torch.Tensor) + assert nodes[node_attr].dtype in [torch.int32, torch.float32] for edges in graph.edge_stores: for edge_attr in edges.edge_attrs(): assert isinstance(edges[edge_attr], torch.Tensor) + assert edges[edge_attr].dtype in [torch.int32, torch.float32] From a8a162056f2e0dd1f90d3e8dd3312169d752bd11 Mon Sep 17 00:00:00 2001 From: Mario Santa Cruz Date: Sat, 6 Jul 2024 11:36:08 +0000 Subject: [PATCH 057/156] fix typo --- src/anemoi/graphs/edges/attributes.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/anemoi/graphs/edges/attributes.py b/src/anemoi/graphs/edges/attributes.py index 99c9669..cf08e47 100644 --- a/src/anemoi/graphs/edges/attributes.py +++ b/src/anemoi/graphs/edges/attributes.py @@ -43,7 +43,7 @@ def compute(self, graph: HeteroData, source_name: str, target_name: str, *args, ), f"Node \"{target_name}\" not found in graph. Optional nodes are {', '.join(graph.node_types)}." values = self.get_raw_values(graph, source_name, target_name, *args, **kwargs) - return self.post_process(normed_values) + return self.post_process(values) @dataclass From 926c75b1486bc8a6960c988ac8edaff032a4ad52 Mon Sep 17 00:00:00 2001 From: Mario Santa Cruz Date: Sat, 6 Jul 2024 11:38:16 +0000 Subject: [PATCH 058/156] Homogeneize base builders --- src/anemoi/graphs/edges/builder.py | 73 +++++++++++++++++++++++------- src/anemoi/graphs/nodes/builder.py | 22 +++++++++ 2 files changed, 79 insertions(+), 16 deletions(-) diff --git a/src/anemoi/graphs/edges/builder.py b/src/anemoi/graphs/edges/builder.py index 3c926d6..2cb048d 100644 --- a/src/anemoi/graphs/edges/builder.py +++ b/src/anemoi/graphs/edges/builder.py @@ -18,6 +18,7 @@ class BaseEdgeBuilder(ABC): + """Base class for edge builders.""" def __init__(self, source_name: str, target_name: str): super().__init__() @@ -27,25 +28,47 @@ def __init__(self, source_name: str, target_name: str): @abstractmethod def get_adjacency_matrix(self, source_nodes: NodeStorage, target_nodes: NodeStorage): ... - def register_edges(self, graph: HeteroData, source_indices: np.ndarray, target_indices: np.ndarray) -> HeteroData: + def prepare_node_data(self, graph: HeteroData) -> tuple[NodeStorage, NodeStorage]: + """Prepare nodes information.""" + return graph[self.source_name], graph[self.target_name] + + def get_edge_index(self, graph: HeteroData) -> torch.Tensor: + """Get the edge indices of source and target nodes. + + Parameters + ---------- + graph : HeteroData + The graph. + + Returns + ------- + torch.Tensor of shape (2, num_edges) + The edge indices. + """ + source_nodes, target_nodes = self.prepare_node_data(graph) + + adjmat = self.get_adjacency_matrix(source_nodes, target_nodes) + + # Get source & target indices of the edges + edge_index = np.stack([adjmat.col, adjmat.row], axis=0) + + return torch.from_numpy(edge_index, dtype=torch.int32) + + def register_edges(self, graph: HeteroData) -> HeteroData: """Register edges in the graph. Parameters ---------- graph : HeteroData The graph to register the edges. - source_indices : np.ndarray of shape (N, ) - The indices of the source nodes. - target_indices : np.ndarray of shape (N, ) - The indices of the target nodes. Returns ------- HeteroData The graph with the registered edges. """ - edge_index = np.stack([source_indices, target_indices], axis=0).astype(np.int32) - graph[(self.source_name, "to", self.target_name)].edge_index = torch.from_numpy(edge_index) + graph[(self.source_name, "to", self.target_name)].edge_index = self.get_edge_index() + graph[(self.source_name, "to", self.target_name)].edge_type = type(self).__name__ return graph def register_attributes(self, graph: HeteroData, config: DotDict) -> HeteroData: @@ -69,10 +92,6 @@ def register_attributes(self, graph: HeteroData, config: DotDict) -> HeteroData: ) return graph - def prepare_node_data(self, graph: HeteroData) -> tuple[NodeStorage, NodeStorage]: - """Prepare nodes information.""" - return graph[self.source_name], graph[self.target_name] - def update_graph(self, graph: HeteroData, attrs_config: Optional[DotDict] = None) -> HeteroData: """Update the graph with the edges. @@ -88,11 +107,9 @@ def update_graph(self, graph: HeteroData, attrs_config: Optional[DotDict] = None HeteroData The graph with the edges. """ - source_nodes, target_nodes = self.prepare_node_data(graph) - - adjmat = self.get_adjacency_matrix(source_nodes, target_nodes) - - graph = self.register_edges(graph, adjmat.col, adjmat.row) + graph = self.register_edges( + graph, + ) if attrs_config is None: return graph @@ -113,6 +130,17 @@ class KNNEdges(BaseEdgeBuilder): The name of the target nodes. num_nearest_neighbours : int Number of nearest neighbours. + + Methods + ------- + get_adjacency_matrix(source_nodes, target_nodes) + Compute the adjacency matrix for the KNN method. + register_edges(graph) + Register the edges in the graph. + register_attributes(graph, config) + Register attributes in the edges of the graph. + update_graph(graph, attrs_config) + Update the graph with the edges. """ def __init__(self, source_name: str, target_name: str, num_nearest_neighbours: int): @@ -162,6 +190,19 @@ class CutOffEdges(BaseEdgeBuilder): Factor to multiply the grid reference distance to get the cut-off radius. radius : float Cut-off radius. + + Methods + ------- + get_cutoff_radius(graph, mask_attr) + Compute the cut-off radius. + get_adjacency_matrix(source_nodes, target_nodes) + Get the adjacency matrix for the cut-off method. + register_edges(graph) + Register the edges in the graph. + register_attributes(graph, config) + Register attributes in the edges of the graph. + update_graph(graph, attrs_config) + Update the graph with the edges. """ def __init__(self, source_name: str, target_name: str, cutoff_factor: float): diff --git a/src/anemoi/graphs/nodes/builder.py b/src/anemoi/graphs/nodes/builder.py index 11e99f6..4bd675b 100644 --- a/src/anemoi/graphs/nodes/builder.py +++ b/src/anemoi/graphs/nodes/builder.py @@ -110,6 +110,17 @@ class ZarrDatasetNodes(BaseNodeBuilder): ---------- ds : zarr.core.Array The dataset. + + Methods + ------- + get_coordinates() + Get the lat-lon coordinates of the nodes. + register_nodes(graph, name) + Register the nodes in the graph. + register_attributes(graph, name, config) + Register the attributes in the nodes of the graph specified. + update_graph(graph, name, attr_config) + Update the graph with new nodes and attributes. """ def __init__(self, dataset: DotDict) -> None: @@ -138,6 +149,17 @@ class NPZFileNodes(BaseNodeBuilder): Path to the folder containing the grid definition files. grid_definition : dict[str, np.ndarray] The grid definition. + + Methods + ------- + get_coordinates() + Get the lat-lon coordinates of the nodes. + register_nodes(graph, name) + Register the nodes in the graph. + register_attributes(graph, name, config) + Register the attributes in the nodes of the graph specified. + update_graph(graph, name, attr_config) + Update the graph with new nodes and attributes. """ def __init__(self, resolution: str, grid_definition_path: str) -> None: From f34207333e1ca0ba2726039b8fc2218015ca0dd1 Mon Sep 17 00:00:00 2001 From: Mario Santa Cruz Date: Sat, 6 Jul 2024 13:50:36 +0000 Subject: [PATCH 059/156] improve test docstrings --- tests/edges/test_edge_attributes.py | 2 +- tests/nodes/test_node_attributes.py | 7 ++++--- tests/nodes/test_npz.py | 10 +++++----- tests/nodes/test_zarr.py | 8 ++++---- 4 files changed, 14 insertions(+), 13 deletions(-) diff --git a/tests/edges/test_edge_attributes.py b/tests/edges/test_edge_attributes.py index b0bbede..5c49119 100644 --- a/tests/edges/test_edge_attributes.py +++ b/tests/edges/test_edge_attributes.py @@ -24,6 +24,6 @@ def test_edge_lengths(graph_nodes_and_edges, norm): @pytest.mark.parametrize("attribute_builder", [EdgeDirection(), EdgeLength()]) def test_fail_edge_features(attribute_builder, graph_nodes_and_edges): - """Test EdgeDirection compute method.""" + """Test edge attribute builder fails with unknown nodes.""" with pytest.raises(AssertionError): attribute_builder.compute(graph_nodes_and_edges, "test_nodes", "unknown_nodes") diff --git a/tests/nodes/test_node_attributes.py b/tests/nodes/test_node_attributes.py index 3d7e5be..300585b 100644 --- a/tests/nodes/test_node_attributes.py +++ b/tests/nodes/test_node_attributes.py @@ -8,7 +8,7 @@ @pytest.mark.parametrize("norm", [None, "l1", "l2", "unit-max", "unit-std"]) def test_uniform_weights(graph_with_nodes: HeteroData, norm: str): - """Test NPZNodes register correctly the weights.""" + """Test attribute builder for UniformWeights.""" node_attr_builder = UniformWeights(norm=norm) weights = node_attr_builder.compute(graph_with_nodes["test_nodes"]) @@ -19,14 +19,14 @@ def test_uniform_weights(graph_with_nodes: HeteroData, norm: str): @pytest.mark.parametrize("norm", ["l3", "invalide"]) def test_uniform_weights_fail(graph_with_nodes: HeteroData, norm: str): - """Test NPZNodes register correctly the weights.""" + """Test attribute builder for UniformWeights with invalid norm.""" with pytest.raises(ValueError): node_attr_builder = UniformWeights(norm=norm) node_attr_builder.compute(graph_with_nodes["test_nodes"]) def test_area_weights(graph_with_nodes: HeteroData): - """Test NPZNodes register correctly the weights.""" + """Test attribute builder for AreaWeights.""" node_attr_builder = AreaWeights() weights = node_attr_builder.compute(graph_with_nodes["test_nodes"]) @@ -37,6 +37,7 @@ def test_area_weights(graph_with_nodes: HeteroData): @pytest.mark.parametrize("radius", [-1.0, "hello", None]) def test_area_weights_fail(graph_with_nodes: HeteroData, radius: float): + """Test attribute builder for AreaWeights with invalid radius.""" with pytest.raises(ValueError): node_attr_builder = AreaWeights(radius=radius) node_attr_builder.compute(graph_with_nodes["test_nodes"]) diff --git a/tests/nodes/test_npz.py b/tests/nodes/test_npz.py index fc4cf8c..02febca 100644 --- a/tests/nodes/test_npz.py +++ b/tests/nodes/test_npz.py @@ -9,7 +9,7 @@ @pytest.mark.parametrize("resolution", ["o16", "o48", "5km5"]) def test_init(mock_grids_path: tuple[str, int], resolution: str): - """Test NPZNodes initialization.""" + """Test NPZFileNodes initialization.""" grid_definition_path, _ = mock_grids_path node_builder = NPZFileNodes(resolution, grid_definition_path=grid_definition_path) assert isinstance(node_builder, NPZFileNodes) @@ -17,21 +17,21 @@ def test_init(mock_grids_path: tuple[str, int], resolution: str): @pytest.mark.parametrize("resolution", ["o17", 13, "ajsnb", None]) def test_fail_init_wrong_resolution(mock_grids_path: tuple[str, int], resolution: str): - """Test NPZNodes initialization with invalid resolution.""" + """Test NPZFileNodes initialization with invalid resolution.""" grid_definition_path, _ = mock_grids_path with pytest.raises(FileNotFoundError): NPZFileNodes(resolution, grid_definition_path=grid_definition_path) def test_fail_init_wrong_path(): - """Test NPZNodes initialization with invalid path.""" + """Test NPZFileNodes initialization with invalid path.""" with pytest.raises(FileNotFoundError): NPZFileNodes("o16", "invalid_path") @pytest.mark.parametrize("resolution", ["o16", "o48", "5km5"]) def test_register_nodes(mock_grids_path: str, resolution: str): - """Test NPZNodes register correctly the nodes.""" + """Test NPZFileNodes register correctly the nodes.""" graph = HeteroData() grid_definition_path, num_nodes = mock_grids_path node_builder = NPZFileNodes(resolution, grid_definition_path=grid_definition_path) @@ -46,7 +46,7 @@ def test_register_nodes(mock_grids_path: str, resolution: str): @pytest.mark.parametrize("attr_class", [UniformWeights, AreaWeights]) def test_register_attributes(graph_with_nodes: HeteroData, mock_grids_path: tuple[str, int], attr_class): - """Test NPZNodes register correctly the weights.""" + """Test NPZFileNodes register correctly the weights.""" grid_definition_path, _ = mock_grids_path node_builder = NPZFileNodes("o16", grid_definition_path=grid_definition_path) config = {"test_attr": {"_target_": f"anemoi.graphs.nodes.attributes.{attr_class.__name__}"}} diff --git a/tests/nodes/test_zarr.py b/tests/nodes/test_zarr.py index 0e91ece..357bc06 100644 --- a/tests/nodes/test_zarr.py +++ b/tests/nodes/test_zarr.py @@ -9,7 +9,7 @@ def test_init(mocker, mock_zarr_dataset): - """Test ZarrNodes initialization.""" + """Test ZarrDatasetNodes initialization.""" mocker.patch.object(builder, "open_dataset", return_value=mock_zarr_dataset) node_builder = builder.ZarrDatasetNodes("dataset.zarr") assert isinstance(node_builder, builder.BaseNodeBuilder) @@ -17,13 +17,13 @@ def test_init(mocker, mock_zarr_dataset): def test_fail_init(): - """Test ZarrNodes initialization with invalid resolution.""" + """Test ZarrDatasetNodes initialization with invalid resolution.""" with pytest.raises(zarr.errors.PathNotFoundError): builder.ZarrDatasetNodes("invalid_path.zarr") def test_register_nodes(mocker, mock_zarr_dataset): - """Test ZarrNodes register correctly the nodes.""" + """Test ZarrDatasetNodes register correctly the nodes.""" mocker.patch.object(builder, "open_dataset", return_value=mock_zarr_dataset) node_builder = builder.ZarrDatasetNodes("dataset.zarr") graph = HeteroData() @@ -38,7 +38,7 @@ def test_register_nodes(mocker, mock_zarr_dataset): @pytest.mark.parametrize("attr_class", [UniformWeights, AreaWeights]) def test_register_attributes(mocker, graph_with_nodes: HeteroData, attr_class): - """Test ZarrNodes register correctly the weights.""" + """Test ZarrDatasetNodes register correctly the weights.""" mocker.patch.object(builder, "open_dataset", return_value=None) node_builder = builder.ZarrDatasetNodes("dataset.zarr") config = {"test_attr": {"_target_": f"anemoi.graphs.nodes.attributes.{attr_class.__name__}"}} From 9d9fea85271276757396b4908e2689cdb4238095 Mon Sep 17 00:00:00 2001 From: Mario Santa Cruz Date: Sat, 6 Jul 2024 14:12:44 +0000 Subject: [PATCH 060/156] homogeneize (name as class attribute) --- src/anemoi/graphs/create.py | 8 +++-- src/anemoi/graphs/edges/builder.py | 50 ++++++++++++++---------------- src/anemoi/graphs/nodes/builder.py | 36 ++++++++++----------- tests/conftest.py | 15 ++++----- tests/nodes/test_npz.py | 14 ++++----- tests/nodes/test_zarr.py | 12 +++---- 6 files changed, 65 insertions(+), 70 deletions(-) diff --git a/src/anemoi/graphs/create.py b/src/anemoi/graphs/create.py index 0b09649..8935c96 100644 --- a/src/anemoi/graphs/create.py +++ b/src/anemoi/graphs/create.py @@ -46,11 +46,13 @@ def generate_graph(self) -> HeteroData: HeteroData: The generated graph. """ graph = HeteroData() - for name, nodes_cfg in self.config.nodes.items(): - graph = instantiate(nodes_cfg.node_builder).update_graph(graph, name, nodes_cfg.get("attributes", {})) + for nodes_cfg in self.config.nodes: + graph = instantiate(nodes_cfg.node_builder, name=nodes_cfg.name).update_graph( + graph, nodes_cfg.get("attributes", {}) + ) for edges_cfg in self.config.edges: - graph = instantiate(edges_cfg.edge_builder, **edges_cfg.nodes).update_graph( + graph = instantiate(edges_cfg.edge_builder, **edges_cfg.names).update_graph( graph, edges_cfg.get("attributes", {}) ) diff --git a/src/anemoi/graphs/edges/builder.py b/src/anemoi/graphs/edges/builder.py index 2cb048d..f9b517d 100644 --- a/src/anemoi/graphs/edges/builder.py +++ b/src/anemoi/graphs/edges/builder.py @@ -1,6 +1,7 @@ import logging from abc import ABC from abc import abstractmethod +from dataclasses import dataclass from typing import Optional import numpy as np @@ -17,20 +18,19 @@ LOGGER = logging.getLogger(__name__) +@dataclass class BaseEdgeBuilder(ABC): """Base class for edge builders.""" - def __init__(self, source_name: str, target_name: str): - super().__init__() - self.source_name = source_name - self.target_name = target_name + source_nodes: str + target_nodes: str @abstractmethod def get_adjacency_matrix(self, source_nodes: NodeStorage, target_nodes: NodeStorage): ... def prepare_node_data(self, graph: HeteroData) -> tuple[NodeStorage, NodeStorage]: """Prepare nodes information.""" - return graph[self.source_name], graph[self.target_name] + return graph[self.source_nodes], graph[self.target_nodes] def get_edge_index(self, graph: HeteroData) -> torch.Tensor: """Get the edge indices of source and target nodes. @@ -52,7 +52,7 @@ def get_edge_index(self, graph: HeteroData) -> torch.Tensor: # Get source & target indices of the edges edge_index = np.stack([adjmat.col, adjmat.row], axis=0) - return torch.from_numpy(edge_index, dtype=torch.int32) + return torch.from_numpy(edge_index).to(torch.float32) def register_edges(self, graph: HeteroData) -> HeteroData: """Register edges in the graph. @@ -67,8 +67,8 @@ def register_edges(self, graph: HeteroData) -> HeteroData: HeteroData The graph with the registered edges. """ - graph[(self.source_name, "to", self.target_name)].edge_index = self.get_edge_index() - graph[(self.source_name, "to", self.target_name)].edge_type = type(self).__name__ + graph[(self.source_nodes, "to", self.target_nodes)].edge_index = self.get_edge_index(graph) + graph[(self.source_nodes, "to", self.target_nodes)].edge_type = type(self).__name__ return graph def register_attributes(self, graph: HeteroData, config: DotDict) -> HeteroData: @@ -87,8 +87,8 @@ def register_attributes(self, graph: HeteroData, config: DotDict) -> HeteroData: The graph with the registered attributes. """ for attr_name, attr_config in config.items(): - graph[self.source_name, "to", self.target_name][attr_name] = instantiate(attr_config).compute( - graph, self.source_name, self.target_name + graph[self.source_nodes, "to", self.target_nodes][attr_name] = instantiate(attr_config).compute( + graph, self.source_nodes, self.target_nodes ) return graph @@ -107,9 +107,7 @@ def update_graph(self, graph: HeteroData, attrs_config: Optional[DotDict] = None HeteroData The graph with the edges. """ - graph = self.register_edges( - graph, - ) + graph = self.register_edges(graph) if attrs_config is None: return graph @@ -124,9 +122,9 @@ class KNNEdges(BaseEdgeBuilder): Attributes ---------- - source_name : str + source_nodes : str The name of the source nodes. - target_name : str + target_nodes : str The name of the target nodes. num_nearest_neighbours : int Number of nearest neighbours. @@ -143,8 +141,8 @@ class KNNEdges(BaseEdgeBuilder): Update the graph with the edges. """ - def __init__(self, source_name: str, target_name: str, num_nearest_neighbours: int): - super().__init__(source_name, target_name) + def __init__(self, source_nodes: str, target_nodes: str, num_nearest_neighbours: int): + super().__init__(source_nodes, target_nodes) assert isinstance(num_nearest_neighbours, int), "Number of nearest neighbours must be an integer" assert num_nearest_neighbours > 0, "Number of nearest neighbours must be positive" self.num_nearest_neighbours = num_nearest_neighbours @@ -163,8 +161,8 @@ def get_adjacency_matrix(self, source_nodes: np.ndarray, target_nodes: np.ndarra LOGGER.info( "Using KNN-Edges (with %d nearest neighbours) between %s and %s.", self.num_nearest_neighbours, - self.source_name, - self.target_name, + self.source_nodes, + self.target_nodes, ) nearest_neighbour = NearestNeighbors(metric="haversine", n_jobs=4) @@ -182,9 +180,9 @@ class CutOffEdges(BaseEdgeBuilder): Attributes ---------- - source_name : str + source_nodes : str The name of the source nodes. - target_name : str + target_nodes : str The name of the target nodes. cutoff_factor : float Factor to multiply the grid reference distance to get the cut-off radius. @@ -205,8 +203,8 @@ class CutOffEdges(BaseEdgeBuilder): Update the graph with the edges. """ - def __init__(self, source_name: str, target_name: str, cutoff_factor: float): - super().__init__(source_name, target_name) + def __init__(self, source_nodes: str, target_nodes: str, cutoff_factor: float): + super().__init__(source_nodes, target_nodes) assert isinstance(cutoff_factor, (int, float)), "Cutoff factor must be a float" assert cutoff_factor > 0, "Cutoff factor must be positive" self.cutoff_factor = cutoff_factor @@ -228,7 +226,7 @@ def get_cutoff_radius(self, graph: HeteroData, mask_attr: Optional[torch.Tensor] float The cut-off radius. """ - target_nodes = graph[self.target_name] + target_nodes = graph[self.target_nodes] mask = target_nodes[mask_attr] if mask_attr is not None else None target_grid_reference_distance = get_grid_reference_distance(target_nodes.x, mask) radius = target_grid_reference_distance * self.cutoff_factor @@ -252,8 +250,8 @@ def get_adjacency_matrix(self, source_nodes: NodeStorage, target_nodes: NodeStor LOGGER.info( "Using CutOff-Edges (with radius = %.1f km) between %s and %s.", self.radius * EARTH_RADIUS, - self.source_name, - self.target_name, + self.source_nodes, + self.target_nodes, ) nearest_neighbour = NearestNeighbors(metric="haversine", n_jobs=4) diff --git a/src/anemoi/graphs/nodes/builder.py b/src/anemoi/graphs/nodes/builder.py index 4bd675b..469c294 100644 --- a/src/anemoi/graphs/nodes/builder.py +++ b/src/anemoi/graphs/nodes/builder.py @@ -1,6 +1,7 @@ import logging from abc import ABC from abc import abstractmethod +from dataclasses import dataclass from pathlib import Path from typing import Optional @@ -14,32 +15,31 @@ LOGGER = logging.getLogger(__name__) +@dataclass class BaseNodeBuilder(ABC): """Base class for node builders.""" - def register_nodes(self, graph: HeteroData, name: str) -> None: + name: str + + def register_nodes(self, graph: HeteroData) -> None: """Register nodes in the graph. Parameters ---------- graph : HeteroData The graph to register the nodes. - name : str - The name of the nodes. """ - graph[name].x = self.get_coordinates() - graph[name].node_type = type(self).__name__ + graph[self.name].x = self.get_coordinates() + graph[self.name].node_type = type(self).__name__ return graph - def register_attributes(self, graph: HeteroData, name: str, config: Optional[DotDict] = None) -> HeteroData: + def register_attributes(self, graph: HeteroData, config: Optional[DotDict] = None) -> HeteroData: """Register attributes in the nodes of the graph specified. Parameters ---------- graph : HeteroData The graph to register the attributes. - name : str - The name of the nodes. config : DotDict The configuration of the attributes. @@ -48,11 +48,8 @@ def register_attributes(self, graph: HeteroData, name: str, config: Optional[Dot HeteroData The graph with the registered attributes. """ - if config is None: - return graph - for attr_name, attr_config in config.items(): - graph[name][attr_name] = instantiate(attr_config).compute(graph[name]) + graph[self.name][attr_name] = instantiate(attr_config).compute(graph[self.name]) return graph @abstractmethod @@ -77,15 +74,13 @@ def reshape_coords(self, latitudes: np.ndarray, longitudes: np.ndarray) -> torch coords = np.deg2rad(coords) return torch.tensor(coords, dtype=torch.float32) - def update_graph(self, graph: HeteroData, name: str, attr_config: Optional[DotDict] = None) -> HeteroData: + def update_graph(self, graph: HeteroData, attr_config: Optional[DotDict] = None) -> HeteroData: """Update the graph with new nodes. Parameters ---------- graph : HeteroData Input graph. - name : str - The name of the nodes. attr_config : DotDict The configuration of the attributes. @@ -94,12 +89,13 @@ def update_graph(self, graph: HeteroData, name: str, attr_config: Optional[DotDi HeteroData The graph with new nodes included. """ - graph = self.register_nodes(graph, name) + graph = self.register_nodes(graph) if attr_config is None: return graph - graph = self.register_attributes(graph, name, attr_config) + graph = self.register_attributes(graph, attr_config) + return graph @@ -123,9 +119,10 @@ class ZarrDatasetNodes(BaseNodeBuilder): Update the graph with new nodes and attributes. """ - def __init__(self, dataset: DotDict) -> None: + def __init__(self, dataset: DotDict, name: str) -> None: LOGGER.info("Reading the dataset from %s.", dataset) self.ds = open_dataset(dataset) + super().__init__(name) def get_coordinates(self) -> torch.Tensor: """Get the coordinates of the nodes. @@ -162,7 +159,7 @@ class NPZFileNodes(BaseNodeBuilder): Update the graph with new nodes and attributes. """ - def __init__(self, resolution: str, grid_definition_path: str) -> None: + def __init__(self, resolution: str, grid_definition_path: str, name: str) -> None: """Initialize the NPZFileNodes builder. The builder suppose the grids are stored in files with the name `grid-{resolution}.npz`. @@ -177,6 +174,7 @@ def __init__(self, resolution: str, grid_definition_path: str) -> None: self.resolution = resolution self.grid_definition_path = grid_definition_path self.grid_definition = np.load(Path(self.grid_definition_path) / f"grid-{self.resolution}.npz") + super().__init__(name) def get_coordinates(self) -> torch.Tensor: """Get the coordinates of the nodes. diff --git a/tests/conftest.py b/tests/conftest.py index 290165c..b0ce308 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -57,28 +57,25 @@ def graph_nodes_and_edges() -> HeteroData: def config_file(tmp_path) -> tuple[str, str]: """Mock grid_definition_path with files for 3 resolutions.""" cfg = { - "nodes": { - "test_nodes": { + "nodes": [ + { + "name": "test_nodes", "node_builder": { "_target_": "anemoi.graphs.nodes.NPZFileNodes", "grid_definition_path": str(tmp_path), "resolution": "o16", }, } - }, + ], "edges": [ { - "nodes": {"source_name": "test_nodes", "target_name": "test_nodes"}, + "names": {"source_nodes": "test_nodes", "target_nodes": "test_nodes"}, "edge_builder": { "_target_": "anemoi.graphs.edges.KNNEdges", "num_nearest_neighbours": 3, }, "attributes": { - "dist_norm": { - "_target_": "anemoi.graphs.edges.attributes.EdgeLength", - "norm": "l1", - "invert": True, - }, + "dist_norm": {"_target_": "anemoi.graphs.edges.attributes.EdgeLength"}, "edge_dirs": {"_target_": "anemoi.graphs.edges.attributes.EdgeDirection"}, }, }, diff --git a/tests/nodes/test_npz.py b/tests/nodes/test_npz.py index 02febca..119bad7 100644 --- a/tests/nodes/test_npz.py +++ b/tests/nodes/test_npz.py @@ -11,7 +11,7 @@ def test_init(mock_grids_path: tuple[str, int], resolution: str): """Test NPZFileNodes initialization.""" grid_definition_path, _ = mock_grids_path - node_builder = NPZFileNodes(resolution, grid_definition_path=grid_definition_path) + node_builder = NPZFileNodes(resolution, grid_definition_path=grid_definition_path, name="test_nodes") assert isinstance(node_builder, NPZFileNodes) @@ -20,13 +20,13 @@ def test_fail_init_wrong_resolution(mock_grids_path: tuple[str, int], resolution """Test NPZFileNodes initialization with invalid resolution.""" grid_definition_path, _ = mock_grids_path with pytest.raises(FileNotFoundError): - NPZFileNodes(resolution, grid_definition_path=grid_definition_path) + NPZFileNodes(resolution, grid_definition_path=grid_definition_path, name="test_nodes") def test_fail_init_wrong_path(): """Test NPZFileNodes initialization with invalid path.""" with pytest.raises(FileNotFoundError): - NPZFileNodes("o16", "invalid_path") + NPZFileNodes("o16", "invalid_path", name="test_nodes") @pytest.mark.parametrize("resolution", ["o16", "o48", "5km5"]) @@ -34,9 +34,9 @@ def test_register_nodes(mock_grids_path: str, resolution: str): """Test NPZFileNodes register correctly the nodes.""" graph = HeteroData() grid_definition_path, num_nodes = mock_grids_path - node_builder = NPZFileNodes(resolution, grid_definition_path=grid_definition_path) + node_builder = NPZFileNodes(resolution, grid_definition_path=grid_definition_path, name="test_nodes") - graph = node_builder.register_nodes(graph, "test_nodes") + graph = node_builder.register_nodes(graph) assert graph["test_nodes"].x is not None assert isinstance(graph["test_nodes"].x, torch.Tensor) @@ -48,10 +48,10 @@ def test_register_nodes(mock_grids_path: str, resolution: str): def test_register_attributes(graph_with_nodes: HeteroData, mock_grids_path: tuple[str, int], attr_class): """Test NPZFileNodes register correctly the weights.""" grid_definition_path, _ = mock_grids_path - node_builder = NPZFileNodes("o16", grid_definition_path=grid_definition_path) + node_builder = NPZFileNodes("o16", grid_definition_path=grid_definition_path, name="test_nodes") config = {"test_attr": {"_target_": f"anemoi.graphs.nodes.attributes.{attr_class.__name__}"}} - graph = node_builder.register_attributes(graph_with_nodes, "test_nodes", config) + graph = node_builder.register_attributes(graph_with_nodes, config) assert graph["test_nodes"]["test_attr"] is not None assert isinstance(graph["test_nodes"]["test_attr"], torch.Tensor) diff --git a/tests/nodes/test_zarr.py b/tests/nodes/test_zarr.py index 357bc06..3b0ca08 100644 --- a/tests/nodes/test_zarr.py +++ b/tests/nodes/test_zarr.py @@ -11,7 +11,7 @@ def test_init(mocker, mock_zarr_dataset): """Test ZarrDatasetNodes initialization.""" mocker.patch.object(builder, "open_dataset", return_value=mock_zarr_dataset) - node_builder = builder.ZarrDatasetNodes("dataset.zarr") + node_builder = builder.ZarrDatasetNodes("dataset.zarr", name="test_nodes") assert isinstance(node_builder, builder.BaseNodeBuilder) assert isinstance(node_builder, builder.ZarrDatasetNodes) @@ -19,16 +19,16 @@ def test_init(mocker, mock_zarr_dataset): def test_fail_init(): """Test ZarrDatasetNodes initialization with invalid resolution.""" with pytest.raises(zarr.errors.PathNotFoundError): - builder.ZarrDatasetNodes("invalid_path.zarr") + builder.ZarrDatasetNodes("invalid_path.zarr", name="test_nodes") def test_register_nodes(mocker, mock_zarr_dataset): """Test ZarrDatasetNodes register correctly the nodes.""" mocker.patch.object(builder, "open_dataset", return_value=mock_zarr_dataset) - node_builder = builder.ZarrDatasetNodes("dataset.zarr") + node_builder = builder.ZarrDatasetNodes("dataset.zarr", name="test_nodes") graph = HeteroData() - graph = node_builder.register_nodes(graph, "test_nodes") + graph = node_builder.register_nodes(graph) assert graph["test_nodes"].x is not None assert isinstance(graph["test_nodes"].x, torch.Tensor) @@ -40,10 +40,10 @@ def test_register_nodes(mocker, mock_zarr_dataset): def test_register_attributes(mocker, graph_with_nodes: HeteroData, attr_class): """Test ZarrDatasetNodes register correctly the weights.""" mocker.patch.object(builder, "open_dataset", return_value=None) - node_builder = builder.ZarrDatasetNodes("dataset.zarr") + node_builder = builder.ZarrDatasetNodes("dataset.zarr", name="test_nodes") config = {"test_attr": {"_target_": f"anemoi.graphs.nodes.attributes.{attr_class.__name__}"}} - graph = node_builder.register_attributes(graph_with_nodes, "test_nodes", config) + graph = node_builder.register_attributes(graph_with_nodes, config) assert graph["test_nodes"]["test_attr"] is not None assert isinstance(graph["test_nodes"]["test_attr"], torch.Tensor) From 1b20845673d575281b91ecd474ff0a3d283c6425 Mon Sep 17 00:00:00 2001 From: Mario Santa Cruz Date: Mon, 8 Jul 2024 09:57:23 +0000 Subject: [PATCH 061/156] new input config --- src/anemoi/graphs/create.py | 2 +- src/anemoi/graphs/edges/attributes.py | 3 +- src/anemoi/graphs/edges/builder.py | 50 ++++++++++++++------------- src/anemoi/graphs/nodes/attributes.py | 11 +++++- src/anemoi/graphs/nodes/builder.py | 12 ++++--- tests/conftest.py | 3 +- tests/edges/test_edge_attributes.py | 6 ++-- tests/nodes/test_node_attributes.py | 8 ++--- tests/nodes/test_npz.py | 10 +++--- 9 files changed, 60 insertions(+), 45 deletions(-) diff --git a/src/anemoi/graphs/create.py b/src/anemoi/graphs/create.py index 8935c96..06f6910 100644 --- a/src/anemoi/graphs/create.py +++ b/src/anemoi/graphs/create.py @@ -52,7 +52,7 @@ def generate_graph(self) -> HeteroData: ) for edges_cfg in self.config.edges: - graph = instantiate(edges_cfg.edge_builder, **edges_cfg.names).update_graph( + graph = instantiate(edges_cfg.edge_builder, edges_cfg.source_name, edges_cfg.target_name).update_graph( graph, edges_cfg.get("attributes", {}) ) diff --git a/src/anemoi/graphs/edges/attributes.py b/src/anemoi/graphs/edges/attributes.py index cf08e47..1d43102 100644 --- a/src/anemoi/graphs/edges/attributes.py +++ b/src/anemoi/graphs/edges/attributes.py @@ -33,8 +33,9 @@ def post_process(self, values: np.ndarray) -> torch.Tensor: return torch.tensor(normed_values, dtype=torch.float32) - def compute(self, graph: HeteroData, source_name: str, target_name: str, *args, **kwargs) -> torch.Tensor: + def compute(self, graph: HeteroData, edges_name: tuple[str, str, str], *args, **kwargs) -> torch.Tensor: """Compute the edge attributes.""" + source_name, _, target_name = edges_name assert ( source_name in graph.node_types ), f"Node \"{source_name}\" not found in graph. Optional nodes are {', '.join(graph.node_types)}." diff --git a/src/anemoi/graphs/edges/builder.py b/src/anemoi/graphs/edges/builder.py index f9b517d..17ba4fc 100644 --- a/src/anemoi/graphs/edges/builder.py +++ b/src/anemoi/graphs/edges/builder.py @@ -1,7 +1,6 @@ import logging from abc import ABC from abc import abstractmethod -from dataclasses import dataclass from typing import Optional import numpy as np @@ -18,19 +17,24 @@ LOGGER = logging.getLogger(__name__) -@dataclass class BaseEdgeBuilder(ABC): """Base class for edge builders.""" - source_nodes: str - target_nodes: str + def __init__(self, source_name: str, target_name: str): + self.source_name = source_name + self.target_name = target_name + + @property + def name(self) -> tuple[str, str, str]: + """Name of the edge subgraph.""" + return self.source_name, "to", self.target_name @abstractmethod def get_adjacency_matrix(self, source_nodes: NodeStorage, target_nodes: NodeStorage): ... def prepare_node_data(self, graph: HeteroData) -> tuple[NodeStorage, NodeStorage]: """Prepare nodes information.""" - return graph[self.source_nodes], graph[self.target_nodes] + return graph[self.source_name], graph[self.target_name] def get_edge_index(self, graph: HeteroData) -> torch.Tensor: """Get the edge indices of source and target nodes. @@ -52,7 +56,7 @@ def get_edge_index(self, graph: HeteroData) -> torch.Tensor: # Get source & target indices of the edges edge_index = np.stack([adjmat.col, adjmat.row], axis=0) - return torch.from_numpy(edge_index).to(torch.float32) + return torch.from_numpy(edge_index).to(torch.int32) def register_edges(self, graph: HeteroData) -> HeteroData: """Register edges in the graph. @@ -67,8 +71,8 @@ def register_edges(self, graph: HeteroData) -> HeteroData: HeteroData The graph with the registered edges. """ - graph[(self.source_nodes, "to", self.target_nodes)].edge_index = self.get_edge_index(graph) - graph[(self.source_nodes, "to", self.target_nodes)].edge_type = type(self).__name__ + graph[self.name].edge_index = self.get_edge_index(graph) + graph[self.name].edge_type = type(self).__name__ return graph def register_attributes(self, graph: HeteroData, config: DotDict) -> HeteroData: @@ -87,9 +91,7 @@ def register_attributes(self, graph: HeteroData, config: DotDict) -> HeteroData: The graph with the registered attributes. """ for attr_name, attr_config in config.items(): - graph[self.source_nodes, "to", self.target_nodes][attr_name] = instantiate(attr_config).compute( - graph, self.source_nodes, self.target_nodes - ) + graph[self.name][attr_name] = instantiate(attr_config).compute(graph, self.name) return graph def update_graph(self, graph: HeteroData, attrs_config: Optional[DotDict] = None) -> HeteroData: @@ -122,9 +124,9 @@ class KNNEdges(BaseEdgeBuilder): Attributes ---------- - source_nodes : str + source_name : str The name of the source nodes. - target_nodes : str + target_name : str The name of the target nodes. num_nearest_neighbours : int Number of nearest neighbours. @@ -141,8 +143,8 @@ class KNNEdges(BaseEdgeBuilder): Update the graph with the edges. """ - def __init__(self, source_nodes: str, target_nodes: str, num_nearest_neighbours: int): - super().__init__(source_nodes, target_nodes) + def __init__(self, source_name: str, target_name: str, num_nearest_neighbours: int): + super().__init__(source_name, target_name) assert isinstance(num_nearest_neighbours, int), "Number of nearest neighbours must be an integer" assert num_nearest_neighbours > 0, "Number of nearest neighbours must be positive" self.num_nearest_neighbours = num_nearest_neighbours @@ -161,8 +163,8 @@ def get_adjacency_matrix(self, source_nodes: np.ndarray, target_nodes: np.ndarra LOGGER.info( "Using KNN-Edges (with %d nearest neighbours) between %s and %s.", self.num_nearest_neighbours, - self.source_nodes, - self.target_nodes, + self.source_name, + self.target_name, ) nearest_neighbour = NearestNeighbors(metric="haversine", n_jobs=4) @@ -180,9 +182,9 @@ class CutOffEdges(BaseEdgeBuilder): Attributes ---------- - source_nodes : str + source_name : str The name of the source nodes. - target_nodes : str + target_name : str The name of the target nodes. cutoff_factor : float Factor to multiply the grid reference distance to get the cut-off radius. @@ -203,8 +205,8 @@ class CutOffEdges(BaseEdgeBuilder): Update the graph with the edges. """ - def __init__(self, source_nodes: str, target_nodes: str, cutoff_factor: float): - super().__init__(source_nodes, target_nodes) + def __init__(self, source_name: str, target_name: str, cutoff_factor: float): + super().__init__(source_name, target_name) assert isinstance(cutoff_factor, (int, float)), "Cutoff factor must be a float" assert cutoff_factor > 0, "Cutoff factor must be positive" self.cutoff_factor = cutoff_factor @@ -226,7 +228,7 @@ def get_cutoff_radius(self, graph: HeteroData, mask_attr: Optional[torch.Tensor] float The cut-off radius. """ - target_nodes = graph[self.target_nodes] + target_nodes = graph[self.target_name] mask = target_nodes[mask_attr] if mask_attr is not None else None target_grid_reference_distance = get_grid_reference_distance(target_nodes.x, mask) radius = target_grid_reference_distance * self.cutoff_factor @@ -250,8 +252,8 @@ def get_adjacency_matrix(self, source_nodes: NodeStorage, target_nodes: NodeStor LOGGER.info( "Using CutOff-Edges (with radius = %.1f km) between %s and %s.", self.radius * EARTH_RADIUS, - self.source_nodes, - self.target_nodes, + self.source_name, + self.target_name, ) nearest_neighbour = NearestNeighbors(metric="haversine", n_jobs=4) diff --git a/src/anemoi/graphs/nodes/attributes.py b/src/anemoi/graphs/nodes/attributes.py index 0bd4d2e..89957aa 100644 --- a/src/anemoi/graphs/nodes/attributes.py +++ b/src/anemoi/graphs/nodes/attributes.py @@ -7,6 +7,7 @@ import numpy as np import torch from scipy.spatial import SphericalVoronoi +from torch_geometric.data import HeteroData from torch_geometric.data.storage import NodeStorage from anemoi.graphs.generate.transforms import latlon_rad_to_cartesian @@ -33,14 +34,22 @@ def post_process(self, values: np.ndarray) -> torch.Tensor: return torch.tensor(norm_values, dtype=torch.float32) - def compute(self, nodes: NodeStorage, *args, **kwargs) -> torch.Tensor: + def compute(self, graph: HeteroData, nodes_name: str, *args, **kwargs) -> torch.Tensor: """Get the node weights. + Parameters + ---------- + graph : HeteroData + Graph. + nodes_name : str + Name of the nodes. + Returns ------- torch.Tensor Weights associated to the nodes. """ + nodes = graph[nodes_name] weights = self.get_raw_values(nodes, *args, **kwargs) return self.post_process(weights) diff --git a/src/anemoi/graphs/nodes/builder.py b/src/anemoi/graphs/nodes/builder.py index 469c294..6ff37a1 100644 --- a/src/anemoi/graphs/nodes/builder.py +++ b/src/anemoi/graphs/nodes/builder.py @@ -1,7 +1,6 @@ import logging from abc import ABC from abc import abstractmethod -from dataclasses import dataclass from pathlib import Path from typing import Optional @@ -15,11 +14,14 @@ LOGGER = logging.getLogger(__name__) -@dataclass class BaseNodeBuilder(ABC): - """Base class for node builders.""" + """Base class for node builders. - name: str + The node coordinates are stored in the `x` attribute of the nodes and they are stored in radians. + """ + + def __init__(self, name: str) -> None: + self.name = name def register_nodes(self, graph: HeteroData) -> None: """Register nodes in the graph. @@ -49,7 +51,7 @@ def register_attributes(self, graph: HeteroData, config: Optional[DotDict] = Non The graph with the registered attributes. """ for attr_name, attr_config in config.items(): - graph[self.name][attr_name] = instantiate(attr_config).compute(graph[self.name]) + graph[self.name][attr_name] = instantiate(attr_config).compute(graph, self.name) return graph @abstractmethod diff --git a/tests/conftest.py b/tests/conftest.py index b0ce308..1dc76de 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -69,7 +69,8 @@ def config_file(tmp_path) -> tuple[str, str]: ], "edges": [ { - "names": {"source_nodes": "test_nodes", "target_nodes": "test_nodes"}, + "source_name": "test_nodes", + "target_name": "test_nodes", "edge_builder": { "_target_": "anemoi.graphs.edges.KNNEdges", "num_nearest_neighbours": 3, diff --git a/tests/edges/test_edge_attributes.py b/tests/edges/test_edge_attributes.py index 5c49119..40cba1c 100644 --- a/tests/edges/test_edge_attributes.py +++ b/tests/edges/test_edge_attributes.py @@ -10,7 +10,7 @@ def test_directional_features(graph_nodes_and_edges, norm, luse_rotated_features: bool): """Test EdgeDirection compute method.""" edge_attr_builder = EdgeDirection(norm=norm, luse_rotated_features=luse_rotated_features) - edge_attr = edge_attr_builder.compute(graph_nodes_and_edges, "test_nodes", "test_nodes") + edge_attr = edge_attr_builder.compute(graph_nodes_and_edges, ("test_nodes", "to", "test_nodes")) assert isinstance(edge_attr, torch.Tensor) @@ -18,7 +18,7 @@ def test_directional_features(graph_nodes_and_edges, norm, luse_rotated_features def test_edge_lengths(graph_nodes_and_edges, norm): """Test EdgeLength compute method.""" edge_attr_builder = EdgeLength(norm=norm) - edge_attr = edge_attr_builder.compute(graph_nodes_and_edges, "test_nodes", "test_nodes") + edge_attr = edge_attr_builder.compute(graph_nodes_and_edges, ("test_nodes", "to", "test_nodes")) assert isinstance(edge_attr, torch.Tensor) @@ -26,4 +26,4 @@ def test_edge_lengths(graph_nodes_and_edges, norm): def test_fail_edge_features(attribute_builder, graph_nodes_and_edges): """Test edge attribute builder fails with unknown nodes.""" with pytest.raises(AssertionError): - attribute_builder.compute(graph_nodes_and_edges, "test_nodes", "unknown_nodes") + attribute_builder.compute(graph_nodes_and_edges, ("test_nodes", "to", "unknown_nodes")) diff --git a/tests/nodes/test_node_attributes.py b/tests/nodes/test_node_attributes.py index 300585b..7347d88 100644 --- a/tests/nodes/test_node_attributes.py +++ b/tests/nodes/test_node_attributes.py @@ -10,7 +10,7 @@ def test_uniform_weights(graph_with_nodes: HeteroData, norm: str): """Test attribute builder for UniformWeights.""" node_attr_builder = UniformWeights(norm=norm) - weights = node_attr_builder.compute(graph_with_nodes["test_nodes"]) + weights = node_attr_builder.compute(graph_with_nodes, "test_nodes") assert weights is not None assert isinstance(weights, torch.Tensor) @@ -22,13 +22,13 @@ def test_uniform_weights_fail(graph_with_nodes: HeteroData, norm: str): """Test attribute builder for UniformWeights with invalid norm.""" with pytest.raises(ValueError): node_attr_builder = UniformWeights(norm=norm) - node_attr_builder.compute(graph_with_nodes["test_nodes"]) + node_attr_builder.compute(graph_with_nodes, "test_nodes") def test_area_weights(graph_with_nodes: HeteroData): """Test attribute builder for AreaWeights.""" node_attr_builder = AreaWeights() - weights = node_attr_builder.compute(graph_with_nodes["test_nodes"]) + weights = node_attr_builder.compute(graph_with_nodes, "test_nodes") assert weights is not None assert isinstance(weights, torch.Tensor) @@ -40,4 +40,4 @@ def test_area_weights_fail(graph_with_nodes: HeteroData, radius: float): """Test attribute builder for AreaWeights with invalid radius.""" with pytest.raises(ValueError): node_attr_builder = AreaWeights(radius=radius) - node_attr_builder.compute(graph_with_nodes["test_nodes"]) + node_attr_builder.compute(graph_with_nodes, "test_nodes") diff --git a/tests/nodes/test_npz.py b/tests/nodes/test_npz.py index 119bad7..95d09c0 100644 --- a/tests/nodes/test_npz.py +++ b/tests/nodes/test_npz.py @@ -11,7 +11,7 @@ def test_init(mock_grids_path: tuple[str, int], resolution: str): """Test NPZFileNodes initialization.""" grid_definition_path, _ = mock_grids_path - node_builder = NPZFileNodes(resolution, grid_definition_path=grid_definition_path, name="test_nodes") + node_builder = NPZFileNodes(resolution, grid_definition_path, "test_nodes") assert isinstance(node_builder, NPZFileNodes) @@ -20,13 +20,13 @@ def test_fail_init_wrong_resolution(mock_grids_path: tuple[str, int], resolution """Test NPZFileNodes initialization with invalid resolution.""" grid_definition_path, _ = mock_grids_path with pytest.raises(FileNotFoundError): - NPZFileNodes(resolution, grid_definition_path=grid_definition_path, name="test_nodes") + NPZFileNodes(resolution, grid_definition_path, "test_nodes") def test_fail_init_wrong_path(): """Test NPZFileNodes initialization with invalid path.""" with pytest.raises(FileNotFoundError): - NPZFileNodes("o16", "invalid_path", name="test_nodes") + NPZFileNodes("o16", "invalid_path", "test_nodes") @pytest.mark.parametrize("resolution", ["o16", "o48", "5km5"]) @@ -34,7 +34,7 @@ def test_register_nodes(mock_grids_path: str, resolution: str): """Test NPZFileNodes register correctly the nodes.""" graph = HeteroData() grid_definition_path, num_nodes = mock_grids_path - node_builder = NPZFileNodes(resolution, grid_definition_path=grid_definition_path, name="test_nodes") + node_builder = NPZFileNodes(resolution, grid_definition_path, "test_nodes") graph = node_builder.register_nodes(graph) @@ -48,7 +48,7 @@ def test_register_nodes(mock_grids_path: str, resolution: str): def test_register_attributes(graph_with_nodes: HeteroData, mock_grids_path: tuple[str, int], attr_class): """Test NPZFileNodes register correctly the weights.""" grid_definition_path, _ = mock_grids_path - node_builder = NPZFileNodes("o16", grid_definition_path=grid_definition_path, name="test_nodes") + node_builder = NPZFileNodes("o16", grid_definition_path, "test_nodes") config = {"test_attr": {"_target_": f"anemoi.graphs.nodes.attributes.{attr_class.__name__}"}} graph = node_builder.register_attributes(graph_with_nodes, config) From 4e62431bfdbead408805591516222efb8f1169d1 Mon Sep 17 00:00:00 2001 From: Mario Santa Cruz Date: Mon, 8 Jul 2024 10:02:32 +0000 Subject: [PATCH 062/156] new default --- src/anemoi/graphs/edges/attributes.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/anemoi/graphs/edges/attributes.py b/src/anemoi/graphs/edges/attributes.py index 1d43102..d49cc53 100644 --- a/src/anemoi/graphs/edges/attributes.py +++ b/src/anemoi/graphs/edges/attributes.py @@ -71,7 +71,7 @@ class EdgeDirection(BaseEdgeAttribute): Compute directional attributes. """ - norm: Optional[str] = None + norm: str = "unit-std" luse_rotated_features: bool = True def get_raw_values(self, graph: HeteroData, source_name: str, target_name: str) -> np.ndarray: @@ -117,8 +117,8 @@ class EdgeLength(BaseEdgeAttribute): Compute edge lengths attributes. """ - norm: str = "l1" - invert: bool = True + norm: str = "unit-std" + invert: bool = False def get_raw_values(self, graph: HeteroData, source_name: str, target_name: str) -> np.ndarray: """Compute haversine distance (in kilometers) between nodes connected by edges. From d25c47b6e916b3b838227c879119578fda2ce859 Mon Sep 17 00:00:00 2001 From: theissenhelen Date: Mon, 24 Jun 2024 14:29:35 +0000 Subject: [PATCH 063/156] feat: Initial implementation of global graphs Co-authored by: Mario Santa Cruz --- .gitignore | 2 + src/anemoi/graphs/edges/connections.py | 131 +++++++++++++++++++++++++ src/anemoi/graphs/generate.py | 36 +++++++ src/anemoi/graphs/nodes/nodes.py | 93 ++++++++++++++++++ src/anemoi/graphs/nodes/weights.py | 62 ++++++++++++ 5 files changed, 324 insertions(+) create mode 100644 src/anemoi/graphs/edges/connections.py create mode 100644 src/anemoi/graphs/generate.py create mode 100644 src/anemoi/graphs/nodes/nodes.py create mode 100644 src/anemoi/graphs/nodes/weights.py diff --git a/.gitignore b/.gitignore index 1b49006..54ecc69 100644 --- a/.gitignore +++ b/.gitignore @@ -189,3 +189,5 @@ _version.py /config* *.pt + +/config* \ No newline at end of file diff --git a/src/anemoi/graphs/edges/connections.py b/src/anemoi/graphs/edges/connections.py new file mode 100644 index 0000000..dcface3 --- /dev/null +++ b/src/anemoi/graphs/edges/connections.py @@ -0,0 +1,131 @@ +from abc import abstractmethod +from dataclasses import dataclass +from typing import Optional + +import networkx as nx +import numpy as np +import torch +from anemoi.utils.config import DotDict +from hydra.utils import instantiate +from sklearn.neighbors import NearestNeighbors +from sklearn.preprocessing import normalize +from torch_geometric.data import HeteroData +from torch_geometric.data.storage import NodeStorage + +from anemoi.graphs import earth_radius +from anemoi.graphs.utils import get_grid_reference_distance + +import logging + +logger = logging.getLogger(__name__) + + +class BaseEdgeBuilder: + """Base class for edge builders.""" + + def __init__(self, src_name: str, dst_name: str): + super().__init__() + self.src_name = src_name + self.dst_name = dst_name + + @abstractmethod + def get_adj_matrix(self, src_nodes: NodeStorage, dst_nodes: NodeStorage): ... + + def register_edges(self, graph, head_indices, tail_indices): + graph[(self.src_name, "to", self.dst_name)].edge_index = np.stack([head_indices, tail_indices], axis=0).astype(np.int32) + return graph + + def register_edge_attribute(self, graph: HeteroData, name: str, values: np.ndarray): + num_edges = graph[(self.src_name, "to", self.dst_name)].num_edges + assert ( + values.shape[0] == num_edges + ), f"Number of edge features ({values.shape[0]}) must match number of edges ({num_edges})." + graph[self.src_name, "to", self.dst_name][name] = values.reshape(num_edges, -1) # TODO: Check the [name] part works + return graph + + def prepare_node_data(self, graph: HeteroData): + return graph[self.src_name], graph[self.dst_name] + + def transform(self, graph: HeteroData, attrs_config: Optional[DotDict] = None) -> HeteroData: + # Get source and destination nodes. + src_nodes, dst_nodes = self.prepare_node_data(graph) + + # Compute adjacency matrix. + adjmat = self.get_adj_matrix(src_nodes, dst_nodes) + + # Normalize adjacency matrix. + adjmat_norm = self.normalize_adjmat(adjmat) + + # Add edges to the graph and register normed distance. + graph = self.register_edges(graph, adjmat.col, adjmat.row) + + self.register_edge_attribute(graph, "normed_dist", adjmat_norm.data) + if attrs_config is not None: + for attr_name, attr_cfg in attrs_config.items(): + attr_values = instantiate(attr_cfg)(graph, self.src_name, self.dst_name) + graph = self.register_edge_attribute(graph, attr_name, attr_values) + + return graph + + def normalize_adjmat(self, adjmat): + """Normalize a sparse adjacency matrix.""" + adjmat_norm = normalize(adjmat, norm="l1", axis=1) + adjmat_norm.data = 1.0 - adjmat_norm.data + return adjmat_norm + + +class KNNEdgeBuilder(BaseEdgeBuilder): + """Computes KNN based edges and adds them to the graph.""" + + def __init__(self, src_name: str, dst_name: str, num_nearest_neighbours: int): + super().__init__(src_name, dst_name) + assert isinstance(num_nearest_neighbours, int), "Number of nearest neighbours must be an integer" + assert num_nearest_neighbours > 0, "Number of nearest neighbours must be positive" + self.num_nearest_neighbours = num_nearest_neighbours + + def get_adj_matrix(self, src_nodes: np.ndarray, dst_nodes: np.ndarray): + assert self.num_nearest_neighbours is not None, "number of neighbors required for knn encoder" + logger.debug( + "Using %d nearest neighbours for KNN-Edges between %s and %s.", + self.num_nearest_neighbours, + self.src_name, + self.dst_name, + ) + + nearest_neighbour = NearestNeighbors(metric="haversine", n_jobs=4) + nearest_neighbour.fit(src_nodes.x.numpy()) + adj_matrix = nearest_neighbour.kneighbors_graph( + dst_nodes.x.numpy(), + n_neighbors=self.num_nearest_neighbours, + mode="distance", + ).tocoo() + return adj_matrix + + +class CutOffEdgeBuilder(BaseEdgeBuilder): + """Computes cut-off based edges and adds them to the graph.""" + + def __init__(self, src_name: str, dst_name: str, cutoff_factor: float): + super().__init__(src_name, dst_name) + assert isinstance(cutoff_factor, float), "Cutoff factor must be a float" + assert cutoff_factor > 0, "Cutoff factor must be positive" + self.cutoff_factor = cutoff_factor + + def get_cutoff_radius(self, dst_nodes: NodeStorage, mask_attr: Optional[torch.Tensor] = None): + mask = dst_nodes[mask_attr] if mask_attr is not None else None + dst_grid_reference_distance = get_grid_reference_distance(dst_nodes.x, mask) + radius = dst_grid_reference_distance * self.cutoff_factor + return radius + + def prepare_node_data(self, graph: HeteroData): + self.radius = self.get_cutoff_radius(graph) + return super().prepare_node_data(graph) + + def get_adj_matrix(self, src_nodes: NodeStorage, dst_nodes: NodeStorage): + logger.debug("Using cut-off radius of %.1f km.", self.radius * earth_radius) + + nearest_neighbour = NearestNeighbors(metric="haversine", n_jobs=4) + nearest_neighbour.fit(src_nodes.x) + adj_matrix = nearest_neighbour.radius_neighbors_graph(dst_nodes.x, radius=self.radius).tocoo() + return adj_matrix + diff --git a/src/anemoi/graphs/generate.py b/src/anemoi/graphs/generate.py new file mode 100644 index 0000000..e44529e --- /dev/null +++ b/src/anemoi/graphs/generate.py @@ -0,0 +1,36 @@ +from abc import ABC +from abc import abstractmethod + +import hydra +from anemoi.utils.config import DotDict +from hydra.utils import instantiate +from omegaconf import DictConfig +from torch_geometric.data import HeteroData + +import logging + +logger = logging.getLogger(__name__) + + +def generate_graph(graph_config): + graph = HeteroData() + + for name, nodes_cfg in graph_config.nodes.items(): + graph = instantiate(nodes_cfg.node_type).transform(graph, name, nodes_cfg.get("attributes", {})) + + for edges_cfg in graph_config.edges: + graph = instantiate(edges_cfg.edge_type, **edges_cfg.nodes).transform(graph, edges_cfg.get("attributes", {})) + + return graph + + +@hydra.main(version_base=None, config_path="../config", config_name="config") +def main(config: DictConfig): + + graph = generate_graph(config) + + return graph + + +if __name__ == "__main__": + main() diff --git a/src/anemoi/graphs/nodes/nodes.py b/src/anemoi/graphs/nodes/nodes.py new file mode 100644 index 0000000..d00dd62 --- /dev/null +++ b/src/anemoi/graphs/nodes/nodes.py @@ -0,0 +1,93 @@ +from abc import abstractmethod +from pathlib import Path +from typing import Optional +from typing import Union + +import h3 +import numpy as np +import torch +from abc import ABC +from anemoi.datasets import open_dataset +from anemoi.utils.config import DotDict +from hydra.utils import instantiate +from sklearn.neighbors import NearestNeighbors +from torch_geometric.data import HeteroData + +from aifs.graphs import GraphBuilder +from aifs.graphs.generate.hexagonal import create_hexagonal_nodes +from aifs.graphs.generate.icosahedral import create_icosahedral_nodes +import logging + +logger = logging.getLogger(__name__) +earth_radius = 6371.0 # km + + +def latlon_to_radians(coords: np.ndarray) -> np.ndarray: + return np.deg2rad(coords) + + +def rad_to_latlon(coords: np.ndarray) -> np.ndarray: + """Converts coordinates from radians to degrees. + + Parameters + ---------- + coords : np.ndarray + Coordinates in radians. + + Returns + ------- + np.ndarray + _description_ + """ + return np.rad2deg(coords) + + +class BaseNodeBuilder(ABC): + + def register_nodes(self, graph: HeteroData, name: str) -> None: + graph[name].x = self.get_coordinates() + graph[name].node_type = type(self).__name__ + return graph + + def register_attributes(self, graph: HeteroData, name: str, config: DotDict) -> HeteroData: + for nodes_attr_name, attr_cfg in config.items(): + graph[name][nodes_attr_name] = instantiate(attr_cfg).get_weights(graph[name]) + return graph + + @abstractmethod + def get_coordinates(self) -> np.ndarray: ... + + def reshape_coords(self, latitudes: np.ndarray, longitudes: np.ndarray) -> np.ndarray: + coords = np.stack([latitudes, longitudes], axis=-1).reshape((-1, 2)) + coords = np.deg2rad(coords) + # TODO: type needs to be variable? + return torch.tensor(coords, dtype=torch.float32) + + def transform(self, graph: HeteroData, name: str, attr_config: DotDict) -> HeteroData: + graph = self.register_nodes(graph, name) + graph = self.register_attributes(graph, name, attr_config) + return graph + + +class ZarrNodes(BaseNodeBuilder): + """Nodes from Zarr dataset.""" + + def __init__(self, dataset: DotDict) -> None: + logger.info("Reading the dataset from %s.", dataset) + self.ds = open_dataset(dataset) + + def get_coordinates(self) -> torch.Tensor: + return self.reshape_coords(self.ds.latitudes, self.ds.longitudes) + + +class NPZNodes(BaseNodeBuilder): + """Nodes from NPZ defined grids.""" + + def __init__(self, resolution: str, grid_definition_path: str) -> None: + self.resolution = resolution + self.grid_definition_path = grid_definition_path + self.grid_definition = np.load(Path(self.grid_definition_path) / f"grid-{self.resolution}.npz") + + def get_coordinates(self) -> np.ndarray: + coords = self.reshape_coords(self.grid_definition["latitudes"], self.grid_definition["longitudes"]) + return coords diff --git a/src/anemoi/graphs/nodes/weights.py b/src/anemoi/graphs/nodes/weights.py new file mode 100644 index 0000000..e2249f1 --- /dev/null +++ b/src/anemoi/graphs/nodes/weights.py @@ -0,0 +1,62 @@ +from abc import ABC +from abc import abstractmethod +from typing import Optional + +import numpy as np +import torch +from torch_geometric.data.storage import NodeStorage + +from anemoi.graphs.generate.transforms import to_sphere_xyz +from scipy.spatial import SphericalVoronoi +from anemoi.graphs.normalizer import NormalizerMixin +import logging + +logger = logging.getLogger(__name__) + +class BaseWeights(ABC, NormalizerMixin): + """Base class for the weights of the nodes.""" + + def __init__(self, norm: Optional[str] = None): + self.norm = norm + + @abstractmethod + def compute(self, nodes: NodeStorage, *args, **kwargs): ... + + def get_weights(self, *args, **kwargs): + weights = self.compute(*args, **kwargs) + if weights.ndim == 1: + weights = weights[:, np.newaxis] + return self.normalize(weights) + + +class UniformWeights(BaseWeights): + """Implements a uniform weight for the nodes.""" + + def __init__(self, norm: str = "unit-max"): + self.norm = norm + + def compute(self, nodes: NodeStorage) -> np.ndarray: + return torch.ones(nodes.num_nodes) + + +class AreaWeights(BaseWeights): + """Implements the area of the nodes as the weights.""" + + def __init__(self, norm: str = "unit-max", radius: float = 1.0, centre: np.ndarray = np.array[0, 0, 0]): + # Weighting of the nodes + self.norm: str = norm + self.radius: float = radius + self.centre: np.ndarray = centre + + def compute(self, nodes: NodeStorage, *args, **kwargs) -> np.ndarray: + # TODO: Check if works + latitudes, longitudes = nodes.x[:, 0], nodes.x[:, 1] + points = to_sphere_xyz((latitudes, longitudes)) + sv = SphericalVoronoi(points, self.radius, self.centre) + area_weights = sv.calculate_areas() + logger.debug( + "There are %d of weights, which (unscaled) add up a total weight of %.2f.", + len(area_weights), + np.array(area_weights).sum(), + ) + return area_weights From 8f0415e05056eed2b00203537207bbf66c63c757 Mon Sep 17 00:00:00 2001 From: Mario Santa Cruz Date: Tue, 25 Jun 2024 11:58:42 +0000 Subject: [PATCH 064/156] add cli command --- src/anemoi/graphs/edges/attributes.py | 3 ++- src/anemoi/graphs/edges/connections.py | 17 ++++++------ src/anemoi/graphs/generate.py | 36 -------------------------- src/anemoi/graphs/nodes/nodes.py | 12 ++------- src/anemoi/graphs/nodes/weights.py | 18 ++++++------- src/anemoi/graphs/normalizer.py | 2 +- 6 files changed, 22 insertions(+), 66 deletions(-) delete mode 100644 src/anemoi/graphs/generate.py diff --git a/src/anemoi/graphs/edges/attributes.py b/src/anemoi/graphs/edges/attributes.py index d49cc53..402563e 100644 --- a/src/anemoi/graphs/edges/attributes.py +++ b/src/anemoi/graphs/edges/attributes.py @@ -4,6 +4,7 @@ from dataclasses import dataclass from typing import Optional +import logging import numpy as np import torch from torch_geometric.data import HeteroData @@ -12,7 +13,7 @@ from anemoi.graphs.normalizer import NormalizerMixin from anemoi.graphs.utils import haversine_distance -LOGGER = logging.getLogger(__name__) +logger = logging.getLogger(__name__) @dataclass diff --git a/src/anemoi/graphs/edges/connections.py b/src/anemoi/graphs/edges/connections.py index dcface3..6bf057e 100644 --- a/src/anemoi/graphs/edges/connections.py +++ b/src/anemoi/graphs/edges/connections.py @@ -1,8 +1,7 @@ +import logging from abc import abstractmethod -from dataclasses import dataclass from typing import Optional -import networkx as nx import numpy as np import torch from anemoi.utils.config import DotDict @@ -15,8 +14,6 @@ from anemoi.graphs import earth_radius from anemoi.graphs.utils import get_grid_reference_distance -import logging - logger = logging.getLogger(__name__) @@ -32,7 +29,9 @@ def __init__(self, src_name: str, dst_name: str): def get_adj_matrix(self, src_nodes: NodeStorage, dst_nodes: NodeStorage): ... def register_edges(self, graph, head_indices, tail_indices): - graph[(self.src_name, "to", self.dst_name)].edge_index = np.stack([head_indices, tail_indices], axis=0).astype(np.int32) + graph[(self.src_name, "to", self.dst_name)].edge_index = np.stack([head_indices, tail_indices], axis=0).astype( + np.int32 + ) return graph def register_edge_attribute(self, graph: HeteroData, name: str, values: np.ndarray): @@ -40,7 +39,9 @@ def register_edge_attribute(self, graph: HeteroData, name: str, values: np.ndarr assert ( values.shape[0] == num_edges ), f"Number of edge features ({values.shape[0]}) must match number of edges ({num_edges})." - graph[self.src_name, "to", self.dst_name][name] = values.reshape(num_edges, -1) # TODO: Check the [name] part works + graph[self.src_name, "to", self.dst_name][name] = values.reshape( + num_edges, -1 + ) # TODO: Check the [name] part works return graph def prepare_node_data(self, graph: HeteroData): @@ -111,7 +112,8 @@ def __init__(self, src_name: str, dst_name: str, cutoff_factor: float): assert cutoff_factor > 0, "Cutoff factor must be positive" self.cutoff_factor = cutoff_factor - def get_cutoff_radius(self, dst_nodes: NodeStorage, mask_attr: Optional[torch.Tensor] = None): + def get_cutoff_radius(self, graph: HeteroData, mask_attr: Optional[torch.Tensor] = None): + dst_nodes = graph[self.dst_name] mask = dst_nodes[mask_attr] if mask_attr is not None else None dst_grid_reference_distance = get_grid_reference_distance(dst_nodes.x, mask) radius = dst_grid_reference_distance * self.cutoff_factor @@ -128,4 +130,3 @@ def get_adj_matrix(self, src_nodes: NodeStorage, dst_nodes: NodeStorage): nearest_neighbour.fit(src_nodes.x) adj_matrix = nearest_neighbour.radius_neighbors_graph(dst_nodes.x, radius=self.radius).tocoo() return adj_matrix - diff --git a/src/anemoi/graphs/generate.py b/src/anemoi/graphs/generate.py deleted file mode 100644 index e44529e..0000000 --- a/src/anemoi/graphs/generate.py +++ /dev/null @@ -1,36 +0,0 @@ -from abc import ABC -from abc import abstractmethod - -import hydra -from anemoi.utils.config import DotDict -from hydra.utils import instantiate -from omegaconf import DictConfig -from torch_geometric.data import HeteroData - -import logging - -logger = logging.getLogger(__name__) - - -def generate_graph(graph_config): - graph = HeteroData() - - for name, nodes_cfg in graph_config.nodes.items(): - graph = instantiate(nodes_cfg.node_type).transform(graph, name, nodes_cfg.get("attributes", {})) - - for edges_cfg in graph_config.edges: - graph = instantiate(edges_cfg.edge_type, **edges_cfg.nodes).transform(graph, edges_cfg.get("attributes", {})) - - return graph - - -@hydra.main(version_base=None, config_path="../config", config_name="config") -def main(config: DictConfig): - - graph = generate_graph(config) - - return graph - - -if __name__ == "__main__": - main() diff --git a/src/anemoi/graphs/nodes/nodes.py b/src/anemoi/graphs/nodes/nodes.py index d00dd62..8de951f 100644 --- a/src/anemoi/graphs/nodes/nodes.py +++ b/src/anemoi/graphs/nodes/nodes.py @@ -1,23 +1,15 @@ +import logging +from abc import ABC from abc import abstractmethod from pathlib import Path -from typing import Optional -from typing import Union -import h3 import numpy as np import torch -from abc import ABC from anemoi.datasets import open_dataset from anemoi.utils.config import DotDict from hydra.utils import instantiate -from sklearn.neighbors import NearestNeighbors from torch_geometric.data import HeteroData -from aifs.graphs import GraphBuilder -from aifs.graphs.generate.hexagonal import create_hexagonal_nodes -from aifs.graphs.generate.icosahedral import create_icosahedral_nodes -import logging - logger = logging.getLogger(__name__) earth_radius = 6371.0 # km diff --git a/src/anemoi/graphs/nodes/weights.py b/src/anemoi/graphs/nodes/weights.py index e2249f1..3afe523 100644 --- a/src/anemoi/graphs/nodes/weights.py +++ b/src/anemoi/graphs/nodes/weights.py @@ -1,18 +1,19 @@ +import logging from abc import ABC from abc import abstractmethod from typing import Optional import numpy as np import torch +from scipy.spatial import SphericalVoronoi from torch_geometric.data.storage import NodeStorage from anemoi.graphs.generate.transforms import to_sphere_xyz -from scipy.spatial import SphericalVoronoi from anemoi.graphs.normalizer import NormalizerMixin -import logging logger = logging.getLogger(__name__) + class BaseWeights(ABC, NormalizerMixin): """Base class for the weights of the nodes.""" @@ -32,9 +33,6 @@ def get_weights(self, *args, **kwargs): class UniformWeights(BaseWeights): """Implements a uniform weight for the nodes.""" - def __init__(self, norm: str = "unit-max"): - self.norm = norm - def compute(self, nodes: NodeStorage) -> np.ndarray: return torch.ones(nodes.num_nodes) @@ -42,14 +40,14 @@ def compute(self, nodes: NodeStorage) -> np.ndarray: class AreaWeights(BaseWeights): """Implements the area of the nodes as the weights.""" - def __init__(self, norm: str = "unit-max", radius: float = 1.0, centre: np.ndarray = np.array[0, 0, 0]): + def __init__(self, norm: str = "unit-max", radius: float = 1.0, centre: np.ndarray = np.array([0, 0, 0])): + super().__init__(norm=norm) + # Weighting of the nodes - self.norm: str = norm - self.radius: float = radius - self.centre: np.ndarray = centre + self.radius = radius + self.centre = centre def compute(self, nodes: NodeStorage, *args, **kwargs) -> np.ndarray: - # TODO: Check if works latitudes, longitudes = nodes.x[:, 0], nodes.x[:, 1] points = to_sphere_xyz((latitudes, longitudes)) sv = SphericalVoronoi(points, self.radius, self.centre) diff --git a/src/anemoi/graphs/normalizer.py b/src/anemoi/graphs/normalizer.py index c625417..bdadfab 100644 --- a/src/anemoi/graphs/normalizer.py +++ b/src/anemoi/graphs/normalizer.py @@ -2,7 +2,7 @@ import numpy as np -LOGGER = logging.getLogger(__name__) +logger = logging.getLogger(__name__) class NormalizerMixin: From 22fba0e761aec26367d58e1b742acc519e50b289 Mon Sep 17 00:00:00 2001 From: Mario Santa Cruz Date: Tue, 25 Jun 2024 12:03:20 +0000 Subject: [PATCH 065/156] Ignore .pt files --- .gitignore | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/.gitignore b/.gitignore index 54ecc69..d2278fa 100644 --- a/.gitignore +++ b/.gitignore @@ -188,6 +188,4 @@ _version.py *.code-workspace /config* -*.pt - -/config* \ No newline at end of file +*.pt \ No newline at end of file From 2879c7ca27d488e66601f7a14ce1c800a725fce3 Mon Sep 17 00:00:00 2001 From: Mario Santa Cruz Date: Tue, 25 Jun 2024 12:32:00 +0000 Subject: [PATCH 066/156] run pre-commit --- .gitignore | 2 +- src/anemoi/graphs/edges/attributes.py | 1 - 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/.gitignore b/.gitignore index d2278fa..1b49006 100644 --- a/.gitignore +++ b/.gitignore @@ -188,4 +188,4 @@ _version.py *.code-workspace /config* -*.pt \ No newline at end of file +*.pt diff --git a/src/anemoi/graphs/edges/attributes.py b/src/anemoi/graphs/edges/attributes.py index 402563e..b93811e 100644 --- a/src/anemoi/graphs/edges/attributes.py +++ b/src/anemoi/graphs/edges/attributes.py @@ -4,7 +4,6 @@ from dataclasses import dataclass from typing import Optional -import logging import numpy as np import torch from torch_geometric.data import HeteroData From 1ca063375a2189da1e322b493ac887438b01d636 Mon Sep 17 00:00:00 2001 From: Mario Santa Cruz Date: Wed, 26 Jun 2024 11:12:52 +0000 Subject: [PATCH 067/156] docstring + log erros --- src/anemoi/graphs/nodes/nodes.py | 23 +---------------------- src/anemoi/graphs/nodes/weights.py | 7 ++++--- 2 files changed, 5 insertions(+), 25 deletions(-) diff --git a/src/anemoi/graphs/nodes/nodes.py b/src/anemoi/graphs/nodes/nodes.py index 8de951f..886125e 100644 --- a/src/anemoi/graphs/nodes/nodes.py +++ b/src/anemoi/graphs/nodes/nodes.py @@ -11,30 +11,10 @@ from torch_geometric.data import HeteroData logger = logging.getLogger(__name__) -earth_radius = 6371.0 # km - - -def latlon_to_radians(coords: np.ndarray) -> np.ndarray: - return np.deg2rad(coords) - - -def rad_to_latlon(coords: np.ndarray) -> np.ndarray: - """Converts coordinates from radians to degrees. - - Parameters - ---------- - coords : np.ndarray - Coordinates in radians. - - Returns - ------- - np.ndarray - _description_ - """ - return np.rad2deg(coords) class BaseNodeBuilder(ABC): + """Base class for node builders.""" def register_nodes(self, graph: HeteroData, name: str) -> None: graph[name].x = self.get_coordinates() @@ -52,7 +32,6 @@ def get_coordinates(self) -> np.ndarray: ... def reshape_coords(self, latitudes: np.ndarray, longitudes: np.ndarray) -> np.ndarray: coords = np.stack([latitudes, longitudes], axis=-1).reshape((-1, 2)) coords = np.deg2rad(coords) - # TODO: type needs to be variable? return torch.tensor(coords, dtype=torch.float32) def transform(self, graph: HeteroData, name: str, attr_config: DotDict) -> HeteroData: diff --git a/src/anemoi/graphs/nodes/weights.py b/src/anemoi/graphs/nodes/weights.py index 3afe523..25419cc 100644 --- a/src/anemoi/graphs/nodes/weights.py +++ b/src/anemoi/graphs/nodes/weights.py @@ -23,18 +23,19 @@ def __init__(self, norm: Optional[str] = None): @abstractmethod def compute(self, nodes: NodeStorage, *args, **kwargs): ... - def get_weights(self, *args, **kwargs): + def get_weights(self, *args, **kwargs) -> torch.Tensor: weights = self.compute(*args, **kwargs) if weights.ndim == 1: weights = weights[:, np.newaxis] - return self.normalize(weights) + norm_weights = self.normalize(weights) + return torch.tensor(norm_weights, dtype=torch.float32) class UniformWeights(BaseWeights): """Implements a uniform weight for the nodes.""" def compute(self, nodes: NodeStorage) -> np.ndarray: - return torch.ones(nodes.num_nodes) + return np.ones(nodes.num_nodes) class AreaWeights(BaseWeights): From 277c231ee7ab5c99eac73fc0b87e62ab6cab7158 Mon Sep 17 00:00:00 2001 From: Mario Santa Cruz Date: Wed, 26 Jun 2024 13:18:59 +0000 Subject: [PATCH 068/156] initial tests --- tests/edges/test_attributes.py | 20 +++++++++++++ tests/nodes/test_weights.py | 53 ++++++++++++++++++++++++++++++++++ 2 files changed, 73 insertions(+) create mode 100644 tests/edges/test_attributes.py create mode 100644 tests/nodes/test_weights.py diff --git a/tests/edges/test_attributes.py b/tests/edges/test_attributes.py new file mode 100644 index 0000000..dcd756d --- /dev/null +++ b/tests/edges/test_attributes.py @@ -0,0 +1,20 @@ +import pytest +import torch + +from anemoi.graphs.edges.attributes import DirectionalFeatures + + +@pytest.mark.parametrize("norm", ["l1", "l2", "unit-max", "unit-sum", "unit-std"]) +@pytest.mark.parametrize("luse_rotated_features", [True, False]) +def test_directional_features(graph_nodes_and_edges, norm, luse_rotated_features: bool): + """Test DirectionalFeatures compute method.""" + edge_attr_builder = DirectionalFeatures(norm=norm, luse_rotated_features=luse_rotated_features) + edge_attr = edge_attr_builder(graph_nodes_and_edges, "test_nodes", "test_nodes") + assert isinstance(edge_attr, torch.Tensor) + + +def test_fail_directional_features(graph_nodes_and_edges): + """Test DirectionalFeatures compute method.""" + edge_attr_builder = DirectionalFeatures() + with pytest.raises(AttributeError): + edge_attr_builder(graph_nodes_and_edges, "test_nodes", "unknown_nodes") diff --git a/tests/nodes/test_weights.py b/tests/nodes/test_weights.py new file mode 100644 index 0000000..db80dce --- /dev/null +++ b/tests/nodes/test_weights.py @@ -0,0 +1,53 @@ +import numpy as np +import pytest +import torch +from hydra.utils import instantiate +from torch_geometric.data import HeteroData + + +@pytest.mark.parametrize("norm", [None, "l1", "l2", "unit-max", "unit-sum", "unit-std"]) +def test_uniform_weights(graph_with_nodes: HeteroData, norm: str): + """Test NPZNodes register correctly the weights.""" + config = {"_target_": "anemoi.graphs.nodes.weights.UniformWeights", "norm": norm} + + weights = instantiate(config).get_weights(graph_with_nodes["test_nodes"]) + + assert weights is not None + assert isinstance(weights, torch.Tensor) + assert weights.shape[0] == graph_with_nodes["test_nodes"].x.shape[0] + + +@pytest.mark.parametrize("norm", ["l3", "invalide"]) +def test_uniform_weights_fail(graph_with_nodes: HeteroData, norm: str): + """Test NPZNodes register correctly the weights.""" + config = {"_target_": "anemoi.graphs.nodes.weights.UniformWeights", "norm": norm} + + with pytest.raises(ValueError): + instantiate(config).get_weights(graph_with_nodes["test_nodes"]) + + +def test_area_weights(graph_with_nodes: HeteroData): + """Test NPZNodes register correctly the weights.""" + config = { + "_target_": "anemoi.graphs.nodes.weights.AreaWeights", + "radius": 1.0, + "centre": np.array([0, 0, 0]), + } + + weights = instantiate(config).get_weights(graph_with_nodes["test_nodes"]) + + assert weights is not None + assert isinstance(weights, torch.Tensor) + assert weights.shape[0] == graph_with_nodes["test_nodes"].x.shape[0] + + +@pytest.mark.parametrize("radius", [-1.0, "hello", None]) +def test_area_weights_fail(graph_with_nodes: HeteroData, radius: float): + config = { + "_target_": "anemoi.graphs.nodes.weights.AreaWeights", + "radius": radius, + "centre": np.array([0, 0, 0]), + } + + with pytest.raises(ValueError): + instantiate(config).get_weights(graph_with_nodes["test_nodes"]) From 7a4e9cbb7a45b9fc6c37f67da25bfd5633a67c7f Mon Sep 17 00:00:00 2001 From: theissenhelen Date: Wed, 26 Jun 2024 14:22:50 +0000 Subject: [PATCH 069/156] feat: initial version of AttributeBuilder --- src/anemoi/graphs/edges/attributes.py | 33 +++++++++++++++++++++++++++ 1 file changed, 33 insertions(+) diff --git a/src/anemoi/graphs/edges/attributes.py b/src/anemoi/graphs/edges/attributes.py index b93811e..bd8b9a0 100644 --- a/src/anemoi/graphs/edges/attributes.py +++ b/src/anemoi/graphs/edges/attributes.py @@ -4,9 +4,12 @@ from dataclasses import dataclass from typing import Optional +import torch +from anemoi.utils.config import DotDict import numpy as np import torch from torch_geometric.data import HeteroData +from hydra.utils import instantiate from anemoi.graphs.edges.directional import directional_edge_features from anemoi.graphs.normalizer import NormalizerMixin @@ -14,6 +17,36 @@ logger = logging.getLogger(__name__) +class AttributeBuilder(): + + def transform(self, graph: HeteroData, graph_config: DotDict): + + for name, nodes_cfg in graph_config.nodes.items(): + graph = self.register_node_attributes(graph, name, nodes_cfg.get("attributes", {})) + for edges_cfg in graph_config.edges: + graph = self.register_edge_attributes(graph, edges_cfg.nodes.src_name, edges_cfg.nodes.dst_name, edges_cfg.get("attributes", {})) + return graph + + def register_node_attributes(self, graph: HeteroData, node_name: str, node_config: DotDict): + assert node_name in graph.keys(), f"Node {node_name} does not exist in the graph." + for attr_name, attr_cfg in node_config.items(): + graph[node_name][attr_name] = instantiate(attr_cfg).compute(graph, node_name) + return graph + + def register_edge_attributes(self, graph: HeteroData, src_name: str, dst_name: str, edge_config: DotDict): + + for attr_name, attr_cfg in edge_config.items(): + attr_values = instantiate(attr_cfg).compute(graph, src_name, dst_name) + graph = self.register_edge_attribute(graph, src_name, dst_name, attr_name, attr_values) + return graph + + def register_edge_attribute(self, graph: HeteroData, src_name: str, dst_name: str, attr_name: str, attr_values: torch.Tensor): + num_edges = graph[(src_name, "to", dst_name)].num_edges + assert ( attr_values.shape[0] == num_edges), f"Number of edge features ({attr_values.shape[0]}) must match number of edges ({num_edges})." + + graph[(src_name, "to", dst_name)][attr_name] = attr_values + return graph + @dataclass class BaseEdgeAttribute(ABC, NormalizerMixin): From 96a6ed51e12a8b8247c757b10ddb64c17bd25666 Mon Sep 17 00:00:00 2001 From: theissenhelen Date: Wed, 26 Jun 2024 14:40:06 +0000 Subject: [PATCH 070/156] refactor: separate into node edge attribute builders --- src/anemoi/graphs/edges/attributes.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/src/anemoi/graphs/edges/attributes.py b/src/anemoi/graphs/edges/attributes.py index bd8b9a0..453b644 100644 --- a/src/anemoi/graphs/edges/attributes.py +++ b/src/anemoi/graphs/edges/attributes.py @@ -17,15 +17,12 @@ logger = logging.getLogger(__name__) -class AttributeBuilder(): +class NodeAttributeBuilder(): def transform(self, graph: HeteroData, graph_config: DotDict): for name, nodes_cfg in graph_config.nodes.items(): graph = self.register_node_attributes(graph, name, nodes_cfg.get("attributes", {})) - for edges_cfg in graph_config.edges: - graph = self.register_edge_attributes(graph, edges_cfg.nodes.src_name, edges_cfg.nodes.dst_name, edges_cfg.get("attributes", {})) - return graph def register_node_attributes(self, graph: HeteroData, node_name: str, node_config: DotDict): assert node_name in graph.keys(), f"Node {node_name} does not exist in the graph." @@ -33,6 +30,13 @@ def register_node_attributes(self, graph: HeteroData, node_name: str, node_confi graph[node_name][attr_name] = instantiate(attr_cfg).compute(graph, node_name) return graph +class EdgeAttributeBuilder(): + + def transform(self, graph: HeteroData, graph_config: DotDict): + for edges_cfg in graph_config.edges: + graph = self.register_edge_attributes(graph, edges_cfg.nodes.src_name, edges_cfg.nodes.dst_name, edges_cfg.get("attributes", {})) + return graph + def register_edge_attributes(self, graph: HeteroData, src_name: str, dst_name: str, edge_config: DotDict): for attr_name, attr_cfg in edge_config.items(): From 1c88ee8d659b3799f608b67936a63c531ec03dc3 Mon Sep 17 00:00:00 2001 From: Mario Santa Cruz Date: Thu, 27 Jun 2024 08:05:05 +0000 Subject: [PATCH 071/156] feat: edge_length moved to edges/attributes.py --- src/anemoi/graphs/create.py | 5 ++++ src/anemoi/graphs/edges/attributes.py | 3 --- src/anemoi/graphs/edges/connections.py | 15 ++---------- tests/nodes/test_weights.py | 33 +++++++++----------------- 4 files changed, 18 insertions(+), 38 deletions(-) diff --git a/src/anemoi/graphs/create.py b/src/anemoi/graphs/create.py index bf3cb0c..cb38418 100644 --- a/src/anemoi/graphs/create.py +++ b/src/anemoi/graphs/create.py @@ -81,3 +81,8 @@ def _path_readable(self) -> bool: return True except FileNotFoundError: return False + + +if __name__ == "__main__": + creator = GraphCreator(config="/home/ecm1924/GitRepos/anemoi-graphs/recipe.yaml", path="graph.pt") + creator.create() diff --git a/src/anemoi/graphs/edges/attributes.py b/src/anemoi/graphs/edges/attributes.py index 453b644..f54510b 100644 --- a/src/anemoi/graphs/edges/attributes.py +++ b/src/anemoi/graphs/edges/attributes.py @@ -4,12 +4,9 @@ from dataclasses import dataclass from typing import Optional -import torch -from anemoi.utils.config import DotDict import numpy as np import torch from torch_geometric.data import HeteroData -from hydra.utils import instantiate from anemoi.graphs.edges.directional import directional_edge_features from anemoi.graphs.normalizer import NormalizerMixin diff --git a/src/anemoi/graphs/edges/connections.py b/src/anemoi/graphs/edges/connections.py index 6bf057e..49080ce 100644 --- a/src/anemoi/graphs/edges/connections.py +++ b/src/anemoi/graphs/edges/connections.py @@ -7,11 +7,10 @@ from anemoi.utils.config import DotDict from hydra.utils import instantiate from sklearn.neighbors import NearestNeighbors -from sklearn.preprocessing import normalize from torch_geometric.data import HeteroData from torch_geometric.data.storage import NodeStorage -from anemoi.graphs import earth_radius +from anemoi.graphs import EARTH_RADIUS from anemoi.graphs.utils import get_grid_reference_distance logger = logging.getLogger(__name__) @@ -54,13 +53,9 @@ def transform(self, graph: HeteroData, attrs_config: Optional[DotDict] = None) - # Compute adjacency matrix. adjmat = self.get_adj_matrix(src_nodes, dst_nodes) - # Normalize adjacency matrix. - adjmat_norm = self.normalize_adjmat(adjmat) - # Add edges to the graph and register normed distance. graph = self.register_edges(graph, adjmat.col, adjmat.row) - self.register_edge_attribute(graph, "normed_dist", adjmat_norm.data) if attrs_config is not None: for attr_name, attr_cfg in attrs_config.items(): attr_values = instantiate(attr_cfg)(graph, self.src_name, self.dst_name) @@ -68,12 +63,6 @@ def transform(self, graph: HeteroData, attrs_config: Optional[DotDict] = None) - return graph - def normalize_adjmat(self, adjmat): - """Normalize a sparse adjacency matrix.""" - adjmat_norm = normalize(adjmat, norm="l1", axis=1) - adjmat_norm.data = 1.0 - adjmat_norm.data - return adjmat_norm - class KNNEdgeBuilder(BaseEdgeBuilder): """Computes KNN based edges and adds them to the graph.""" @@ -124,7 +113,7 @@ def prepare_node_data(self, graph: HeteroData): return super().prepare_node_data(graph) def get_adj_matrix(self, src_nodes: NodeStorage, dst_nodes: NodeStorage): - logger.debug("Using cut-off radius of %.1f km.", self.radius * earth_radius) + logger.debug("Using cut-off radius of %.1f km.", self.radius * EARTH_RADIUS) nearest_neighbour = NearestNeighbors(metric="haversine", n_jobs=4) nearest_neighbour.fit(src_nodes.x) diff --git a/tests/nodes/test_weights.py b/tests/nodes/test_weights.py index db80dce..71e54fa 100644 --- a/tests/nodes/test_weights.py +++ b/tests/nodes/test_weights.py @@ -1,16 +1,16 @@ -import numpy as np import pytest import torch -from hydra.utils import instantiate from torch_geometric.data import HeteroData +from anemoi.graphs.nodes.weights import AreaWeights +from anemoi.graphs.nodes.weights import UniformWeights + @pytest.mark.parametrize("norm", [None, "l1", "l2", "unit-max", "unit-sum", "unit-std"]) def test_uniform_weights(graph_with_nodes: HeteroData, norm: str): """Test NPZNodes register correctly the weights.""" - config = {"_target_": "anemoi.graphs.nodes.weights.UniformWeights", "norm": norm} - - weights = instantiate(config).get_weights(graph_with_nodes["test_nodes"]) + node_attr_builder = UniformWeights(norm=norm) + weights = node_attr_builder.get_weights(graph_with_nodes["test_nodes"]) assert weights is not None assert isinstance(weights, torch.Tensor) @@ -20,21 +20,15 @@ def test_uniform_weights(graph_with_nodes: HeteroData, norm: str): @pytest.mark.parametrize("norm", ["l3", "invalide"]) def test_uniform_weights_fail(graph_with_nodes: HeteroData, norm: str): """Test NPZNodes register correctly the weights.""" - config = {"_target_": "anemoi.graphs.nodes.weights.UniformWeights", "norm": norm} - with pytest.raises(ValueError): - instantiate(config).get_weights(graph_with_nodes["test_nodes"]) + node_attr_builder = UniformWeights(norm=norm) + node_attr_builder.get_weights(graph_with_nodes["test_nodes"]) def test_area_weights(graph_with_nodes: HeteroData): """Test NPZNodes register correctly the weights.""" - config = { - "_target_": "anemoi.graphs.nodes.weights.AreaWeights", - "radius": 1.0, - "centre": np.array([0, 0, 0]), - } - - weights = instantiate(config).get_weights(graph_with_nodes["test_nodes"]) + node_attr_builder = AreaWeights() + weights = node_attr_builder.get_weights(graph_with_nodes["test_nodes"]) assert weights is not None assert isinstance(weights, torch.Tensor) @@ -43,11 +37,6 @@ def test_area_weights(graph_with_nodes: HeteroData): @pytest.mark.parametrize("radius", [-1.0, "hello", None]) def test_area_weights_fail(graph_with_nodes: HeteroData, radius: float): - config = { - "_target_": "anemoi.graphs.nodes.weights.AreaWeights", - "radius": radius, - "centre": np.array([0, 0, 0]), - } - with pytest.raises(ValueError): - instantiate(config).get_weights(graph_with_nodes["test_nodes"]) + node_attr_builder = AreaWeights(radius=radius) + node_attr_builder.get_weights(graph_with_nodes["test_nodes"]) From 92988d7b141fba43a46414df982e47f8e5001d94 Mon Sep 17 00:00:00 2001 From: Mario Santa Cruz Date: Thu, 27 Jun 2024 08:05:34 +0000 Subject: [PATCH 072/156] remove __init__ --- src/anemoi/graphs/create.py | 5 ----- 1 file changed, 5 deletions(-) diff --git a/src/anemoi/graphs/create.py b/src/anemoi/graphs/create.py index cb38418..bf3cb0c 100644 --- a/src/anemoi/graphs/create.py +++ b/src/anemoi/graphs/create.py @@ -81,8 +81,3 @@ def _path_readable(self) -> bool: return True except FileNotFoundError: return False - - -if __name__ == "__main__": - creator = GraphCreator(config="/home/ecm1924/GitRepos/anemoi-graphs/recipe.yaml", path="graph.pt") - creator.create() From de074886c3f85e11e7fa9c22b6fbdf4453707888 Mon Sep 17 00:00:00 2001 From: Mario Santa Cruz Date: Thu, 27 Jun 2024 14:54:17 +0000 Subject: [PATCH 073/156] bugfix (encoder edge lengths) + refector --- src/anemoi/graphs/edges/connections.py | 121 ------------------------- src/anemoi/graphs/nodes/nodes.py | 64 ------------- 2 files changed, 185 deletions(-) delete mode 100644 src/anemoi/graphs/edges/connections.py delete mode 100644 src/anemoi/graphs/nodes/nodes.py diff --git a/src/anemoi/graphs/edges/connections.py b/src/anemoi/graphs/edges/connections.py deleted file mode 100644 index 49080ce..0000000 --- a/src/anemoi/graphs/edges/connections.py +++ /dev/null @@ -1,121 +0,0 @@ -import logging -from abc import abstractmethod -from typing import Optional - -import numpy as np -import torch -from anemoi.utils.config import DotDict -from hydra.utils import instantiate -from sklearn.neighbors import NearestNeighbors -from torch_geometric.data import HeteroData -from torch_geometric.data.storage import NodeStorage - -from anemoi.graphs import EARTH_RADIUS -from anemoi.graphs.utils import get_grid_reference_distance - -logger = logging.getLogger(__name__) - - -class BaseEdgeBuilder: - """Base class for edge builders.""" - - def __init__(self, src_name: str, dst_name: str): - super().__init__() - self.src_name = src_name - self.dst_name = dst_name - - @abstractmethod - def get_adj_matrix(self, src_nodes: NodeStorage, dst_nodes: NodeStorage): ... - - def register_edges(self, graph, head_indices, tail_indices): - graph[(self.src_name, "to", self.dst_name)].edge_index = np.stack([head_indices, tail_indices], axis=0).astype( - np.int32 - ) - return graph - - def register_edge_attribute(self, graph: HeteroData, name: str, values: np.ndarray): - num_edges = graph[(self.src_name, "to", self.dst_name)].num_edges - assert ( - values.shape[0] == num_edges - ), f"Number of edge features ({values.shape[0]}) must match number of edges ({num_edges})." - graph[self.src_name, "to", self.dst_name][name] = values.reshape( - num_edges, -1 - ) # TODO: Check the [name] part works - return graph - - def prepare_node_data(self, graph: HeteroData): - return graph[self.src_name], graph[self.dst_name] - - def transform(self, graph: HeteroData, attrs_config: Optional[DotDict] = None) -> HeteroData: - # Get source and destination nodes. - src_nodes, dst_nodes = self.prepare_node_data(graph) - - # Compute adjacency matrix. - adjmat = self.get_adj_matrix(src_nodes, dst_nodes) - - # Add edges to the graph and register normed distance. - graph = self.register_edges(graph, adjmat.col, adjmat.row) - - if attrs_config is not None: - for attr_name, attr_cfg in attrs_config.items(): - attr_values = instantiate(attr_cfg)(graph, self.src_name, self.dst_name) - graph = self.register_edge_attribute(graph, attr_name, attr_values) - - return graph - - -class KNNEdgeBuilder(BaseEdgeBuilder): - """Computes KNN based edges and adds them to the graph.""" - - def __init__(self, src_name: str, dst_name: str, num_nearest_neighbours: int): - super().__init__(src_name, dst_name) - assert isinstance(num_nearest_neighbours, int), "Number of nearest neighbours must be an integer" - assert num_nearest_neighbours > 0, "Number of nearest neighbours must be positive" - self.num_nearest_neighbours = num_nearest_neighbours - - def get_adj_matrix(self, src_nodes: np.ndarray, dst_nodes: np.ndarray): - assert self.num_nearest_neighbours is not None, "number of neighbors required for knn encoder" - logger.debug( - "Using %d nearest neighbours for KNN-Edges between %s and %s.", - self.num_nearest_neighbours, - self.src_name, - self.dst_name, - ) - - nearest_neighbour = NearestNeighbors(metric="haversine", n_jobs=4) - nearest_neighbour.fit(src_nodes.x.numpy()) - adj_matrix = nearest_neighbour.kneighbors_graph( - dst_nodes.x.numpy(), - n_neighbors=self.num_nearest_neighbours, - mode="distance", - ).tocoo() - return adj_matrix - - -class CutOffEdgeBuilder(BaseEdgeBuilder): - """Computes cut-off based edges and adds them to the graph.""" - - def __init__(self, src_name: str, dst_name: str, cutoff_factor: float): - super().__init__(src_name, dst_name) - assert isinstance(cutoff_factor, float), "Cutoff factor must be a float" - assert cutoff_factor > 0, "Cutoff factor must be positive" - self.cutoff_factor = cutoff_factor - - def get_cutoff_radius(self, graph: HeteroData, mask_attr: Optional[torch.Tensor] = None): - dst_nodes = graph[self.dst_name] - mask = dst_nodes[mask_attr] if mask_attr is not None else None - dst_grid_reference_distance = get_grid_reference_distance(dst_nodes.x, mask) - radius = dst_grid_reference_distance * self.cutoff_factor - return radius - - def prepare_node_data(self, graph: HeteroData): - self.radius = self.get_cutoff_radius(graph) - return super().prepare_node_data(graph) - - def get_adj_matrix(self, src_nodes: NodeStorage, dst_nodes: NodeStorage): - logger.debug("Using cut-off radius of %.1f km.", self.radius * EARTH_RADIUS) - - nearest_neighbour = NearestNeighbors(metric="haversine", n_jobs=4) - nearest_neighbour.fit(src_nodes.x) - adj_matrix = nearest_neighbour.radius_neighbors_graph(dst_nodes.x, radius=self.radius).tocoo() - return adj_matrix diff --git a/src/anemoi/graphs/nodes/nodes.py b/src/anemoi/graphs/nodes/nodes.py deleted file mode 100644 index 886125e..0000000 --- a/src/anemoi/graphs/nodes/nodes.py +++ /dev/null @@ -1,64 +0,0 @@ -import logging -from abc import ABC -from abc import abstractmethod -from pathlib import Path - -import numpy as np -import torch -from anemoi.datasets import open_dataset -from anemoi.utils.config import DotDict -from hydra.utils import instantiate -from torch_geometric.data import HeteroData - -logger = logging.getLogger(__name__) - - -class BaseNodeBuilder(ABC): - """Base class for node builders.""" - - def register_nodes(self, graph: HeteroData, name: str) -> None: - graph[name].x = self.get_coordinates() - graph[name].node_type = type(self).__name__ - return graph - - def register_attributes(self, graph: HeteroData, name: str, config: DotDict) -> HeteroData: - for nodes_attr_name, attr_cfg in config.items(): - graph[name][nodes_attr_name] = instantiate(attr_cfg).get_weights(graph[name]) - return graph - - @abstractmethod - def get_coordinates(self) -> np.ndarray: ... - - def reshape_coords(self, latitudes: np.ndarray, longitudes: np.ndarray) -> np.ndarray: - coords = np.stack([latitudes, longitudes], axis=-1).reshape((-1, 2)) - coords = np.deg2rad(coords) - return torch.tensor(coords, dtype=torch.float32) - - def transform(self, graph: HeteroData, name: str, attr_config: DotDict) -> HeteroData: - graph = self.register_nodes(graph, name) - graph = self.register_attributes(graph, name, attr_config) - return graph - - -class ZarrNodes(BaseNodeBuilder): - """Nodes from Zarr dataset.""" - - def __init__(self, dataset: DotDict) -> None: - logger.info("Reading the dataset from %s.", dataset) - self.ds = open_dataset(dataset) - - def get_coordinates(self) -> torch.Tensor: - return self.reshape_coords(self.ds.latitudes, self.ds.longitudes) - - -class NPZNodes(BaseNodeBuilder): - """Nodes from NPZ defined grids.""" - - def __init__(self, resolution: str, grid_definition_path: str) -> None: - self.resolution = resolution - self.grid_definition_path = grid_definition_path - self.grid_definition = np.load(Path(self.grid_definition_path) / f"grid-{self.resolution}.npz") - - def get_coordinates(self) -> np.ndarray: - coords = self.reshape_coords(self.grid_definition["latitudes"], self.grid_definition["longitudes"]) - return coords From c900bfd670e517865a433b97ac57b7b52c823332 Mon Sep 17 00:00:00 2001 From: Mario Santa Cruz Date: Fri, 28 Jun 2024 14:52:01 +0000 Subject: [PATCH 074/156] feat: support path and dict for `config` argument --- src/anemoi/graphs/create.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/src/anemoi/graphs/create.py b/src/anemoi/graphs/create.py index bf3cb0c..06b8f2b 100644 --- a/src/anemoi/graphs/create.py +++ b/src/anemoi/graphs/create.py @@ -26,6 +26,11 @@ def __init__( else: self.config = config + if isinstance(config, str) or isinstance(config, os.PathLike): + self.config = DotDict.from_file(self.config) + else: + self.config = config + self.path = path # Output path self.cache = cache self.print = print From 9768204b0f5238c4d6c2e9f1decfe9ae61c2bf1d Mon Sep 17 00:00:00 2001 From: Mario Santa Cruz Date: Fri, 28 Jun 2024 15:15:41 +0000 Subject: [PATCH 075/156] fix: error --- src/anemoi/graphs/create.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/src/anemoi/graphs/create.py b/src/anemoi/graphs/create.py index 06b8f2b..c4f1030 100644 --- a/src/anemoi/graphs/create.py +++ b/src/anemoi/graphs/create.py @@ -26,10 +26,6 @@ def __init__( else: self.config = config - if isinstance(config, str) or isinstance(config, os.PathLike): - self.config = DotDict.from_file(self.config) - else: - self.config = config self.path = path # Output path self.cache = cache From 74ef4ccdaed9235dafa2db1282e28a35e6e49bc7 Mon Sep 17 00:00:00 2001 From: theissenhelen Date: Mon, 1 Jul 2024 10:12:40 +0000 Subject: [PATCH 076/156] refactor: naming --- src/anemoi/graphs/edges/__init__.py | 1 + src/anemoi/graphs/nodes/weights.py | 61 ----------------------------- tests/nodes/test_weights.py | 42 -------------------- 3 files changed, 1 insertion(+), 103 deletions(-) delete mode 100644 src/anemoi/graphs/nodes/weights.py delete mode 100644 tests/nodes/test_weights.py diff --git a/src/anemoi/graphs/edges/__init__.py b/src/anemoi/graphs/edges/__init__.py index 53b9c74..19d48db 100644 --- a/src/anemoi/graphs/edges/__init__.py +++ b/src/anemoi/graphs/edges/__init__.py @@ -2,3 +2,4 @@ from .builder import KNNEdges __all__ = ["KNNEdges", "CutOffEdges"] + diff --git a/src/anemoi/graphs/nodes/weights.py b/src/anemoi/graphs/nodes/weights.py deleted file mode 100644 index 25419cc..0000000 --- a/src/anemoi/graphs/nodes/weights.py +++ /dev/null @@ -1,61 +0,0 @@ -import logging -from abc import ABC -from abc import abstractmethod -from typing import Optional - -import numpy as np -import torch -from scipy.spatial import SphericalVoronoi -from torch_geometric.data.storage import NodeStorage - -from anemoi.graphs.generate.transforms import to_sphere_xyz -from anemoi.graphs.normalizer import NormalizerMixin - -logger = logging.getLogger(__name__) - - -class BaseWeights(ABC, NormalizerMixin): - """Base class for the weights of the nodes.""" - - def __init__(self, norm: Optional[str] = None): - self.norm = norm - - @abstractmethod - def compute(self, nodes: NodeStorage, *args, **kwargs): ... - - def get_weights(self, *args, **kwargs) -> torch.Tensor: - weights = self.compute(*args, **kwargs) - if weights.ndim == 1: - weights = weights[:, np.newaxis] - norm_weights = self.normalize(weights) - return torch.tensor(norm_weights, dtype=torch.float32) - - -class UniformWeights(BaseWeights): - """Implements a uniform weight for the nodes.""" - - def compute(self, nodes: NodeStorage) -> np.ndarray: - return np.ones(nodes.num_nodes) - - -class AreaWeights(BaseWeights): - """Implements the area of the nodes as the weights.""" - - def __init__(self, norm: str = "unit-max", radius: float = 1.0, centre: np.ndarray = np.array([0, 0, 0])): - super().__init__(norm=norm) - - # Weighting of the nodes - self.radius = radius - self.centre = centre - - def compute(self, nodes: NodeStorage, *args, **kwargs) -> np.ndarray: - latitudes, longitudes = nodes.x[:, 0], nodes.x[:, 1] - points = to_sphere_xyz((latitudes, longitudes)) - sv = SphericalVoronoi(points, self.radius, self.centre) - area_weights = sv.calculate_areas() - logger.debug( - "There are %d of weights, which (unscaled) add up a total weight of %.2f.", - len(area_weights), - np.array(area_weights).sum(), - ) - return area_weights diff --git a/tests/nodes/test_weights.py b/tests/nodes/test_weights.py deleted file mode 100644 index 71e54fa..0000000 --- a/tests/nodes/test_weights.py +++ /dev/null @@ -1,42 +0,0 @@ -import pytest -import torch -from torch_geometric.data import HeteroData - -from anemoi.graphs.nodes.weights import AreaWeights -from anemoi.graphs.nodes.weights import UniformWeights - - -@pytest.mark.parametrize("norm", [None, "l1", "l2", "unit-max", "unit-sum", "unit-std"]) -def test_uniform_weights(graph_with_nodes: HeteroData, norm: str): - """Test NPZNodes register correctly the weights.""" - node_attr_builder = UniformWeights(norm=norm) - weights = node_attr_builder.get_weights(graph_with_nodes["test_nodes"]) - - assert weights is not None - assert isinstance(weights, torch.Tensor) - assert weights.shape[0] == graph_with_nodes["test_nodes"].x.shape[0] - - -@pytest.mark.parametrize("norm", ["l3", "invalide"]) -def test_uniform_weights_fail(graph_with_nodes: HeteroData, norm: str): - """Test NPZNodes register correctly the weights.""" - with pytest.raises(ValueError): - node_attr_builder = UniformWeights(norm=norm) - node_attr_builder.get_weights(graph_with_nodes["test_nodes"]) - - -def test_area_weights(graph_with_nodes: HeteroData): - """Test NPZNodes register correctly the weights.""" - node_attr_builder = AreaWeights() - weights = node_attr_builder.get_weights(graph_with_nodes["test_nodes"]) - - assert weights is not None - assert isinstance(weights, torch.Tensor) - assert weights.shape[0] == graph_with_nodes["test_nodes"].x.shape[0] - - -@pytest.mark.parametrize("radius", [-1.0, "hello", None]) -def test_area_weights_fail(graph_with_nodes: HeteroData, radius: float): - with pytest.raises(ValueError): - node_attr_builder = AreaWeights(radius=radius) - node_attr_builder.get_weights(graph_with_nodes["test_nodes"]) From 91e15b3aeaee16e30814ed7a6f03cc3bbd3a6062 Mon Sep 17 00:00:00 2001 From: theissenhelen Date: Mon, 1 Jul 2024 10:19:05 +0000 Subject: [PATCH 077/156] fix: pre-commit --- src/anemoi/graphs/edges/__init__.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/anemoi/graphs/edges/__init__.py b/src/anemoi/graphs/edges/__init__.py index 19d48db..53b9c74 100644 --- a/src/anemoi/graphs/edges/__init__.py +++ b/src/anemoi/graphs/edges/__init__.py @@ -2,4 +2,3 @@ from .builder import KNNEdges __all__ = ["KNNEdges", "CutOffEdges"] - From 66d3aa1e0ffa8c24a47c82d7f795726c5e6a3329 Mon Sep 17 00:00:00 2001 From: theissenhelen Date: Fri, 28 Jun 2024 12:06:22 +0000 Subject: [PATCH 078/156] feat: builders icosahedral --- src/anemoi/graphs/nodes/builder.py | 52 ++++++++++++++++++++++++++++++ 1 file changed, 52 insertions(+) diff --git a/src/anemoi/graphs/nodes/builder.py b/src/anemoi/graphs/nodes/builder.py index 6ff37a1..100f5d8 100644 --- a/src/anemoi/graphs/nodes/builder.py +++ b/src/anemoi/graphs/nodes/builder.py @@ -23,6 +23,9 @@ class BaseNodeBuilder(ABC): def __init__(self, name: str) -> None: self.name = name + def __init__(self) -> None: + self.aoi_mask_builder = None + def register_nodes(self, graph: HeteroData) -> None: """Register nodes in the graph. @@ -188,3 +191,52 @@ def get_coordinates(self) -> torch.Tensor: """ coords = self.reshape_coords(self.grid_definition["latitudes"], self.grid_definition["longitudes"]) return coords + + +class RefinedIcosahedralNodeBuilder(BaseNodeBuilder): + """Processor mesh based on a triangular mesh. + + It is based on the icosahedral mesh, which is a mesh of triangles that covers the sphere. + + Parameters + ---------- + resolution : list[int] | int + Refinement level of the mesh. + np_dtype : np.dtype, optional + The numpy data type to use, by default np.float32. + """ + + def __init__( + self, + resolution: Union[int, list[int]], + np_dtype: np.dtype = np.float32, + ) -> None: + self.np_dtype = np_dtype + + if isinstance(resolution, int): + self.resolutions = list(range(resolution + 1)) + else: + self.resolutions = resolution + + super().__init__() + + def get_coordinates(self) -> np.ndarray: + self.nx_graph, coords_rad, self.node_ordering = self.create_nodes() + return coords_rad[self.node_ordering] + + def create_nodes(self) -> np.ndarray: ... + + def register_attributes(self, graph: HeteroData, name: str, config: DotDict) -> HeteroData: + graph[name]["resolutions"] = self.resolutions + graph[name]["nx_graph"] = self.nx_graph + graph[name]["node_ordering"] = self.node_ordering + graph[name]["aoi_mask_builder"] = self.aoi_mask_builder + return super().register_attributes(graph, name, config) + + +class TriRefinedIcosahedralNodeBuilder(RefinedIcosahedralNodeBuilder): + """It depends on the trimesh Python library.""" + + def create_nodes(self) -> np.ndarray: + return create_icosahedral_nodes(resolutions=self.resolutions, aoi_nneighb=self.aoi_mask_builder) + From 23b311074843561e511645efb3a081c3204cd3cc Mon Sep 17 00:00:00 2001 From: theissenhelen Date: Fri, 28 Jun 2024 15:30:13 +0000 Subject: [PATCH 079/156] feat: Add icosahedral graph generation Co-authored-by: Mario Santa Cruz --- src/anemoi/graphs/generate/icosahedral.py | 216 ++++++++++++++++++++++ src/anemoi/graphs/nodes/builder.py | 7 +- 2 files changed, 221 insertions(+), 2 deletions(-) create mode 100644 src/anemoi/graphs/generate/icosahedral.py diff --git a/src/anemoi/graphs/generate/icosahedral.py b/src/anemoi/graphs/generate/icosahedral.py new file mode 100644 index 0000000..d28d7ef --- /dev/null +++ b/src/anemoi/graphs/generate/icosahedral.py @@ -0,0 +1,216 @@ +from collections.abc import Iterable +from typing import Optional + +import networkx as nx +import numpy as np +import trimesh +from sklearn.metrics.pairwise import haversine_distances +from sklearn.neighbors import BallTree + +from anemoi.graphs.generate.transforms import cartesian_to_latlon_rad +import logging + +logger = logging.getLogger(__name__) + + +def create_icosahedral_nodes( + resolutions: list[int], +) -> tuple[nx.DiGraph, np.ndarray, list[int]]: + """Creates a global mesh following AIFS strategy. + + This method relies on the trimesh python library. + + Parameters + ---------- + resolutions : list[int] + Levels of mesh resolution to consider. + aoi_mask_builder : KNNAreaMaskBuilder + KNNAreaMaskBuilder with the cloud of points to limit the mesh area, by default None. + + Returns + ------- + graph : networkx.Graph + The specified graph (nodes & edges). + vertices_rad : np.ndarray + The vertices (not ordered) of the mesh in radians. + node_ordering : list[int] + Order of the nodes in the graph to be sorted by latitude and longitude. + """ + sphere = create_sphere(resolutions[-1]) + coords_rad = cartesian_to_latlon_rad(sphere.vertices) + + node_ordering = get_node_ordering(coords_rad) + + # TODO: AOI mask builder is not used in the current implementation. + + nx_graph = create_icosahedral_nx_graph_from_coords(coords_rad, node_ordering) + + return nx_graph, coords_rad, list(node_ordering) + + +def create_icosahedral_nx_graph_from_coords(coords_rad: np.ndarray, node_ordering: list[int]): + + graph = nx.DiGraph() + for ii, coords in enumerate(coords_rad[node_ordering]): + node_id = node_ordering[ii] + graph.add_node(node_id, hcoords_rad=coords) + + assert list(graph.nodes.keys()) == list(node_ordering), "Nodes are not correctly added to the graph." + assert graph.number_of_nodes() == len(node_ordering), "The number of nodes must be the same." + return graph + + +def get_node_ordering(vertices_rad: np.ndarray) -> np.ndarray: + # Get indices to sort points by lon & lat in radians. + ind1 = np.argsort(vertices_rad[:, 1]) + ind2 = np.argsort(vertices_rad[ind1][:, 0])[::-1] + node_ordering = np.arange(vertices_rad.shape[0])[ind1][ind2] + return node_ordering + + +def add_edges_to_nx_graph( + graph: nx.DiGraph, + resolutions: list[int], + xhops: int = 1, +) -> None: + """Adds the edges to the graph. + + Parameters + ---------- + graph : nx.DiGraph + The graph to add the edges. It should correspond to the mesh nodes, without edges. + resolutions : list[int] + Levels of mesh refinement levels to consider. + xhops : int, optional + Number of hops between 2 nodes to consider them neighbours, by default 1. + aoi_mask_builder : KNNAreaMaskBuilder + NearestNeighbors with the cloud of points to limit the mesh area, by default None. + margin_radius_km : float, optional + Margin radius in km to consider when creating the processor mesh, by default 0.0. + """ + assert xhops > 0, "xhops == 0, graph would have no edges ..." + + sphere = create_sphere(resolutions[-1]) + vertices_rad = cartesian_to_latlon_rad(sphere.vertices) + x_hops = get_x_hops(sphere, xhops, valid_nodes=list(graph.nodes)) + + for i, i_neighbours in x_hops.items(): + add_neigbours_edges(graph, vertices_rad, i, i_neighbours) + + tree = BallTree(vertices_rad, metric="haversine") + + for resolution in resolutions[:-1]: + # Defined refined sphere + r_sphere = create_sphere(resolution) + r_vertices_rad = cartesian_to_latlon_rad(r_sphere.vertices) + + # TODO AOI mask builder is not used in the current implementation. + valid_nodes = None + + x_rings = get_x_hops(r_sphere, xhops, valid_nodes=valid_nodes) + + _, idx = tree.query(r_vertices_rad, k=1) + for i, i_neighbours in x_rings.items(): + add_neigbours_edges(graph, r_vertices_rad, i, i_neighbours, idx=idx) + + return graph + + +def create_sphere(subdivisions: int = 0, radius: float = 1.0) -> trimesh.Trimesh: + """Creates a sphere. + + Parameters + ---------- + subdivisions : int, optional + How many times to subdivide the mesh. Note that the number of faces will grow as function of 4 ** subdivisions. + Defaults to 0. + radius : float, optional + Radius of the sphere created, by default 1.0 + + Returns + ------- + trimesh.Trimesh + Meshed sphere. + """ + return trimesh.creation.icosphere(subdivisions=subdivisions, radius=radius) + + +def get_x_hops(sp: trimesh.Trimesh, hops: int, valid_nodes: Optional[list[int]] = None) -> dict[int, set[int]]: + """Get the neigbour connections in the graph. + + Parameters + ---------- + sp : trimesh.Trimesh + The mesh to consider. + hops : int + Number of hops between 2 nodes to consider them neighbours. + valid_nodes : list[int], optional + List of valid nodes to consider, by default None. It is useful to consider only a subset of the nodes to save + computation time. + + Returns + ------- + neighbours : dict[int, set[int]] + A list with the neighbours for each vertex. The element at position 'i' correspond to the neighbours to the + i-th vertex of the mesh. + """ + edges = sp.edges_unique + if valid_nodes is not None: + edges = edges[np.isin(sp.edges_unique, valid_nodes).all(axis=1)] + else: + valid_nodes = list(range(len(sp.vertices))) + g = nx.from_edgelist(edges) + + neighbours = {ii: set(nx.ego_graph(g, ii, radius=hops, center=False) if ii in g else []) for ii in valid_nodes} + + return neighbours + + +def add_neigbours_edges( + graph: nx.Graph, + vertices: np.ndarray, + ii: int, + neighbours: Iterable[int], + self_loops: bool = False, + idx: Optional[np.ndarray] = None, +) -> None: + """Adds the edges of one node to its neighbours. + + Parameters + ---------- + graph : nx.Graph + The graph. + vertices : np.ndarray + A 2D array of shape (num_vertices, 2) with the planar coordinates of the mesh, in radians. + ii : int + The node considered. + neighbours : list[int] + The neighbours of the node. + self_loops : bool, optional + Whether is supported to add self-loops, by default False. + idx : np.ndarray, optional + Index to map the vertices from the refined sphere to the original one, by default None. + """ + for ineighb in neighbours: + if not self_loops and ii == ineighb: # no self-loops + continue + + loc_self = vertices[ii] + loc_neigh = vertices[ineighb] + edge_length = haversine_distances([loc_neigh, loc_self])[0][1] + + if idx is not None: + # Use the same method to add edge in all spheres + node_neigh = idx[ineighb][0] + node = idx[ii][0] + else: + node, node_neigh = ii, ineighb + + # add edge to the graph + if node in graph and node_neigh in graph: + graph.add_edge(node_neigh, node, weight=edge_length) + + + + + diff --git a/src/anemoi/graphs/nodes/builder.py b/src/anemoi/graphs/nodes/builder.py index 100f5d8..9ee7c37 100644 --- a/src/anemoi/graphs/nodes/builder.py +++ b/src/anemoi/graphs/nodes/builder.py @@ -3,6 +3,7 @@ from abc import abstractmethod from pathlib import Path from typing import Optional +from typing import Union import numpy as np import torch @@ -10,6 +11,7 @@ from anemoi.utils.config import DotDict from hydra.utils import instantiate from torch_geometric.data import HeteroData +from anemoi.graphs.generate.icosahedral import create_icosahedral_nodes LOGGER = logging.getLogger(__name__) @@ -230,7 +232,7 @@ def register_attributes(self, graph: HeteroData, name: str, config: DotDict) -> graph[name]["resolutions"] = self.resolutions graph[name]["nx_graph"] = self.nx_graph graph[name]["node_ordering"] = self.node_ordering - graph[name]["aoi_mask_builder"] = self.aoi_mask_builder + # TODO: AOI mask builder is not used in the current implementation. return super().register_attributes(graph, name, config) @@ -238,5 +240,6 @@ class TriRefinedIcosahedralNodeBuilder(RefinedIcosahedralNodeBuilder): """It depends on the trimesh Python library.""" def create_nodes(self) -> np.ndarray: - return create_icosahedral_nodes(resolutions=self.resolutions, aoi_nneighb=self.aoi_mask_builder) + # TODO: AOI mask builder is not used in the current implementation. + return create_icosahedral_nodes(resolutions=self.resolutions) From 469b56b25db97525bdb684e9b12bc741f2e2869e Mon Sep 17 00:00:00 2001 From: theissenhelen Date: Fri, 28 Jun 2024 15:38:13 +0000 Subject: [PATCH 080/156] refactor: remove create_shere --- src/anemoi/graphs/generate/icosahedral.py | 54 +++++++------------ src/anemoi/graphs/nodes/nodes.py | 65 +++++++++++++++++++++++ 2 files changed, 83 insertions(+), 36 deletions(-) create mode 100644 src/anemoi/graphs/nodes/nodes.py diff --git a/src/anemoi/graphs/generate/icosahedral.py b/src/anemoi/graphs/generate/icosahedral.py index d28d7ef..97c4801 100644 --- a/src/anemoi/graphs/generate/icosahedral.py +++ b/src/anemoi/graphs/generate/icosahedral.py @@ -36,7 +36,8 @@ def create_icosahedral_nodes( node_ordering : list[int] Order of the nodes in the graph to be sorted by latitude and longitude. """ - sphere = create_sphere(resolutions[-1]) + sphere = trimesh.creation.icosphere(subdivisions=resolutions[-1], radius=1.0) + coords_rad = cartesian_to_latlon_rad(sphere.vertices) node_ordering = get_node_ordering(coords_rad) @@ -90,7 +91,7 @@ def add_edges_to_nx_graph( """ assert xhops > 0, "xhops == 0, graph would have no edges ..." - sphere = create_sphere(resolutions[-1]) + sphere = trimesh.creation.icosphere(subdivisions=resolutions[-1], radius=1.0) vertices_rad = cartesian_to_latlon_rad(sphere.vertices) x_hops = get_x_hops(sphere, xhops, valid_nodes=list(graph.nodes)) @@ -101,7 +102,7 @@ def add_edges_to_nx_graph( for resolution in resolutions[:-1]: # Defined refined sphere - r_sphere = create_sphere(resolution) + r_sphere = trimesh.creation.icosphere(subdivisions=resolution, radius=1.0) r_vertices_rad = cartesian_to_latlon_rad(r_sphere.vertices) # TODO AOI mask builder is not used in the current implementation. @@ -116,31 +117,12 @@ def add_edges_to_nx_graph( return graph -def create_sphere(subdivisions: int = 0, radius: float = 1.0) -> trimesh.Trimesh: - """Creates a sphere. - - Parameters - ---------- - subdivisions : int, optional - How many times to subdivide the mesh. Note that the number of faces will grow as function of 4 ** subdivisions. - Defaults to 0. - radius : float, optional - Radius of the sphere created, by default 1.0 - - Returns - ------- - trimesh.Trimesh - Meshed sphere. - """ - return trimesh.creation.icosphere(subdivisions=subdivisions, radius=radius) - - -def get_x_hops(sp: trimesh.Trimesh, hops: int, valid_nodes: Optional[list[int]] = None) -> dict[int, set[int]]: +def get_x_hops(tri_mesh: trimesh.Trimesh, hops: int, valid_nodes: Optional[list[int]] = None) -> dict[int, set[int]]: """Get the neigbour connections in the graph. Parameters ---------- - sp : trimesh.Trimesh + tri_mesh : trimesh.Trimesh The mesh to consider. hops : int Number of hops between 2 nodes to consider them neighbours. @@ -154,11 +136,11 @@ def get_x_hops(sp: trimesh.Trimesh, hops: int, valid_nodes: Optional[list[int]] A list with the neighbours for each vertex. The element at position 'i' correspond to the neighbours to the i-th vertex of the mesh. """ - edges = sp.edges_unique + edges = tri_mesh.edges_unique if valid_nodes is not None: - edges = edges[np.isin(sp.edges_unique, valid_nodes).all(axis=1)] + edges = edges[np.isin(tri_mesh.edges_unique, valid_nodes).all(axis=1)] else: - valid_nodes = list(range(len(sp.vertices))) + valid_nodes = list(range(len(tri_mesh.vertices))) g = nx.from_edgelist(edges) neighbours = {ii: set(nx.ego_graph(g, ii, radius=hops, center=False) if ii in g else []) for ii in valid_nodes} @@ -191,24 +173,24 @@ def add_neigbours_edges( idx : np.ndarray, optional Index to map the vertices from the refined sphere to the original one, by default None. """ - for ineighb in neighbours: - if not self_loops and ii == ineighb: # no self-loops + for idx_neighbour in neighbours: + if not self_loops and ii == idx_neighbour: # no self-loops continue - loc_self = vertices[ii] - loc_neigh = vertices[ineighb] - edge_length = haversine_distances([loc_neigh, loc_self])[0][1] + location_node = vertices[ii] + location_neighbour = vertices[idx_neighbour] + edge_length = haversine_distances([location_neighbour, location_node])[0][1] if idx is not None: # Use the same method to add edge in all spheres - node_neigh = idx[ineighb][0] + node_neighbour = idx[idx_neighbour][0] node = idx[ii][0] else: - node, node_neigh = ii, ineighb + node, node_neighbour = ii, idx_neighbour # add edge to the graph - if node in graph and node_neigh in graph: - graph.add_edge(node_neigh, node, weight=edge_length) + if node in graph and node_neighbour in graph: + graph.add_edge(node_neighbour, node, weight=edge_length) diff --git a/src/anemoi/graphs/nodes/nodes.py b/src/anemoi/graphs/nodes/nodes.py new file mode 100644 index 0000000..3d59e5f --- /dev/null +++ b/src/anemoi/graphs/nodes/nodes.py @@ -0,0 +1,65 @@ +import logging +from abc import ABC +from abc import abstractmethod +from pathlib import Path + +import numpy as np +import torch +from anemoi.datasets import open_dataset +from anemoi.utils.config import DotDict +from hydra.utils import instantiate +from torch_geometric.data import HeteroData + +logger = logging.getLogger(__name__) + + +class BaseNodeBuilder(ABC): + """Base class for node builders.""" + + def register_nodes(self, graph: HeteroData, name: str) -> None: + graph[name].x = self.get_coordinates() + graph[name].node_type = type(self).__name__ + return graph + + def register_attributes(self, graph: HeteroData, name: str, config: DotDict) -> HeteroData: + for nodes_attr_name, attr_cfg in config.items(): + graph[name][nodes_attr_name] = instantiate(attr_cfg).get_weights(graph[name]) + return graph + + @abstractmethod + def get_coordinates(self) -> np.ndarray: ... + + def reshape_coords(self, latitudes: np.ndarray, longitudes: np.ndarray) -> np.ndarray: + coords = np.stack([latitudes, longitudes], axis=-1).reshape((-1, 2)) + coords = np.deg2rad(coords) + return torch.tensor(coords, dtype=torch.float32) + + def transform(self, graph: HeteroData, name: str, attr_config: DotDict) -> HeteroData: + graph = self.register_nodes(graph, name) + graph = self.register_attributes(graph, name, attr_config) + return graph + + + +class ZarrNodes(BaseNodeBuilder): + """Nodes from Zarr dataset.""" + + def __init__(self, dataset: DotDict) -> None: + logger.info("Reading the dataset from %s.", dataset) + self.ds = open_dataset(dataset) + + def get_coordinates(self) -> torch.Tensor: + return self.reshape_coords(self.ds.latitudes, self.ds.longitudes) + + +class NPZNodes(BaseNodeBuilder): + """Nodes from NPZ defined grids.""" + + def __init__(self, resolution: str, grid_definition_path: str) -> None: + self.resolution = resolution + self.grid_definition_path = grid_definition_path + self.grid_definition = np.load(Path(self.grid_definition_path) / f"grid-{self.resolution}.npz") + + def get_coordinates(self) -> np.ndarray: + coords = self.reshape_coords(self.grid_definition["latitudes"], self.grid_definition["longitudes"]) + return coords From 4e864c32f0f0fbbc6eb9fff3d85b65d2cfc4a2a3 Mon Sep 17 00:00:00 2001 From: theissenhelen Date: Fri, 28 Jun 2024 15:46:35 +0000 Subject: [PATCH 081/156] feat: Icosahedral edge builder --- src/anemoi/graphs/edges/builder.py | 45 ++++++++++++++++++++++++++++++ 1 file changed, 45 insertions(+) diff --git a/src/anemoi/graphs/edges/builder.py b/src/anemoi/graphs/edges/builder.py index 17ba4fc..81b4709 100644 --- a/src/anemoi/graphs/edges/builder.py +++ b/src/anemoi/graphs/edges/builder.py @@ -5,6 +5,7 @@ import numpy as np import torch +import networkx as nx from anemoi.utils.config import DotDict from hydra.utils import instantiate from sklearn.neighbors import NearestNeighbors @@ -13,6 +14,8 @@ from anemoi.graphs import EARTH_RADIUS from anemoi.graphs.utils import get_grid_reference_distance +from anemoi.graphs.nodes.builder import TriRefinedIcosahedralNodeBuilder +from anemoi.graphs.generate import icosahedral LOGGER = logging.getLogger(__name__) @@ -260,3 +263,45 @@ def get_adjacency_matrix(self, source_nodes: NodeStorage, target_nodes: NodeStor nearest_neighbour.fit(source_nodes.x) adj_matrix = nearest_neighbour.radius_neighbors_graph(target_nodes.x, radius=self.radius).tocoo() return adj_matrix + + +class TriIcosahedralEdgeBuilder(BaseEdgeBuilder): + """Computes icosahedral edges and adds them to a HeteroData graph.""" + + def __init__(self, src_name: str, dst_name: str, xhops: int): + super().__init__(src_name, dst_name) + + assert isinstance(xhops, int), "Number of xhops must be an integer" + assert xhops > 0, "Number of xhops must be positive" + + self.xhops = xhops + + def transform(self, graph: HeteroData, edge_name: str, attrs_config: Optional[DotDict] = None) -> HeteroData: + + assert ( + graph[self.src_name].node_type == TriRefinedIcosahedralNodeBuilder.__name__ + ), "IcosahedralConnection requires MultiScaleIcosahedral nodes." + assert graph[self.src_name] == graph[self.dst_name], "InheritConnection requires the same nodes for source and destination." + + # TODO: Next assert doesn't exist anymore since filters were moved, make sure this is checked where appropriate + # assert filter_src is None and filter_dst is None, "InheritConnection does not support filtering with attributes." + + return super().transform(graph, edge_name, attrs_config) + + def get_adj_matrix(self, src_nodes: NodeStorage, dst_nodes: NodeStorage): + + src_nodes["nx_graph"] = icosahedral.add_edges_to_nx_graph( + src_nodes["nx_graph"], + resolutions=src_nodes["resolutions"], + xhops=self.xhops, + aoi_nneighb=None if "aoi_nneighb" not in src_nodes else src_nodes["aoi_nneigh"], + ) # HeteroData refuses to accept None + + adjmat = nx.to_scipy_sparse_array(src_nodes["nx_graph"], nodelist=list(src_nodes["nx_graph"]), format="coo") + graph_1_sorted = dict(zip(range(len(src_nodes["nx_graph"].nodes)), list(src_nodes["nx_graph"].nodes))) + graph_2_sorted = dict(zip(src_nodes.node_ordering, range(len(src_nodes.node_ordering)))) + sort_func1 = np.vectorize(graph_1_sorted.get) + sort_func2 = np.vectorize(graph_2_sorted.get) + adjmat.row = sort_func2(sort_func1(adjmat.row)) + adjmat.col = sort_func2(sort_func1(adjmat.col)) + return adjmat From d57390539f223debc297a30cd162a23c5867301b Mon Sep 17 00:00:00 2001 From: theissenhelen Date: Fri, 28 Jun 2024 16:08:38 +0000 Subject: [PATCH 082/156] feat: hexagonal graph generation Co-authored-by: Mario Santa Cruz --- src/anemoi/graphs/generate/hexagonal.py | 260 ++++++++++++++++++++++++ 1 file changed, 260 insertions(+) create mode 100644 src/anemoi/graphs/generate/hexagonal.py diff --git a/src/anemoi/graphs/generate/hexagonal.py b/src/anemoi/graphs/generate/hexagonal.py new file mode 100644 index 0000000..902d88b --- /dev/null +++ b/src/anemoi/graphs/generate/hexagonal.py @@ -0,0 +1,260 @@ +from typing import Optional + +import h3 +import networkx as nx +import numpy as np +import torch +from sklearn.metrics.pairwise import haversine_distances + + +def add_edge( + graph: nx.Graph, + idx1: str, + idx2: str, + allow_self_loop: bool = False, +) -> None: + """Add edge between two nodes to a graph. + + The edge will only be added in case both tail and head nodes are included in the graph, G. + + Parameters + ---------- + graph : networkx.Graph + The graph to add the nodes. + idx1 : str + The H3 index of the tail of the edge. + idx2 : str + The H3 index of the head of the edge. + allow_self_loop : bool + Whether to allow self-loops or not. Defaults to not allowing self-loops. + """ + if not graph.has_node(idx1) or not graph.has_node(idx2): + return + + if allow_self_loop or idx1 != idx2: + loc1 = np.deg2rad(h3.h3_to_geo(idx1)) + loc2 = np.deg2rad(h3.h3_to_geo(idx2)) + graph.add_edge(idx1, idx2, weight=haversine_distances([loc1, loc2])[0][1]) + + +def get_cells_at_resolution( + resolution: int, + area: Optional[dict] = None, + aoi_mask_builder: Optional[KNNAreaMaskBuilder] = None, +) -> set[str]: + """Get cells at a specified refinement level. + + Parameters + ---------- + resolution : int + The H3 refinement level. It can be an integer from 0 to 15. + area : dict + A region, in GeoJSON data format, to be contained by all cells. Defaults to None. + aoi_mask_builder : KNNAreaMaskBuilder, optional + KNNAreaMaskBuilder computes nask to limit the mesh area, by default None. + + Returns + ------- + cells : set[str] + The set of H3 indexes at the specified resolution level. + """ + # TODO: What is area? + cells = h3.uncompact(h3.get_res0_indexes(), resolution) if area is None else h3.polyfill(area, resolution) + + if aoi_mask_builder is not None: + cells = list(cells) + + coords = np.deg2rad(np.array([h3.h3_to_geo(c) for c in cells])) + aoi_mask = aoi_mask_builder.get_mask(coords) + + cells = set(map(str, np.array(cells)[aoi_mask])) + + return cells + + +def add_nodes_for_resolution( + graph: nx.Graph, + resolution: int, + self_loop: bool = False, + **area_kwargs: Optional[dict], +) -> None: + """Add all nodes at a specified refinement level to a graph. + + Parameters + ---------- + graph : networkx.Graph + The graph to add the nodes. + resolution : int + The H3 refinement level. It can be an integer from 0 to 15. + self_loop : int + Whether to include self-loops in the nodes added or not. + area_kwargs: dict + Additional arguments to pass to the get_cells_at_resolution function. + """ + for idx in get_cells_at_resolution(resolution, **area_kwargs): + graph.add_node(idx, hcoords_rad=np.deg2rad(h3.h3_to_geo(idx))) + if self_loop: + # TODO: should that be add_self_loops(graph)? + add_edge(graph, idx, idx, allow_self_loop=self_loop) + + +def add_neighbour_edges( + graph: nx.Graph, + refinement_levels: tuple[int], + flat: bool = True, +) -> None: + for resolution in refinement_levels: + cells = {node for node in graph.nodes if h3.h3_get_resolution(node) == resolution} + for idx in cells: + k = 2 if resolution == 0 else 1 # refinement_levels[0]: # extra large field of vision ; only few nodes + + # neighbours + for idx_neighbour in h3.k_ring(idx, k=k) & cells: + if flat: + add_edge( + graph, + h3.h3_to_center_child(idx, refinement_levels[-1]), + h3.h3_to_center_child(idx_neighbour, refinement_levels[-1]), + ) + else: + add_edge(graph, idx, idx_neighbour) + + +def create_hexagonal_nodes( + resolutions: list[int], + flat: bool = True, + area: Optional[dict] = None, + aoi_mask_builder: Optional[KNNAreaMaskBuilder] = None, +) -> tuple[nx.Graph, torch.Tensor, list[int]]: + """Creates a global mesh from a refined icosahedro. + + This method relies on the H3 python library, which covers the earth with hexagons (and 5 pentagons). At each + refinement level, a hexagon cell has 7 child cells (aperture 7). + + Parameters + ---------- + resolutions : list[int] + Levels of mesh resolution to consider. + flat : bool + Whether or not all resolution levels of the mesh are included. + area : dict + A region, in GeoJSON data format, to be contained by all cells. Defaults to None, which computes the global + mesh. + aoi_mask_builder : KNNAreaMaskBuilder, optional + KNNAreaMaskBuilder with the cloud of points to limit the mesh area, by default None. + + Returns + ------- + graph : networkx.Graph + The specified graph (nodes & edges). + """ + graph = nx.Graph() + + area_kwargs = {"area": area, "aoi_mask_builder": aoi_mask_builder} + + for resolution in resolutions: + add_nodes_for_resolution(graph, resolution, **area_kwargs) + + coords = np.array([h3.h3_to_geo(node) for node in graph.nodes]) + + # Sort nodes by latitude and longitude + node_ordering = np.lexsort(coords.T[::-1], axis=0) + + # Should these be sorted here or in the edge builder? + coords = coords[node_ordering] + + return graph, coords, node_ordering + + +def add_edges_to_nx_graph( + graph: nx.Graph, + resolutions: list[int], + self_loop: bool = False, + flat: bool = True, + neighbour_children: bool = False, + depth_children: int = 1, +) -> nx.Graph: + """Creates a global mesh from a refined icosahedro. + + This method relies on the H3 python library, which covers the earth with hexagons (and 5 pentagons). At each + refinement level, a hexagon cell has 7 child cells (aperture 7). + + Parameters + ---------- + graph : networkx.Graph + The graph to add the nodes. + resolutions : list[int] + Levels of mesh resolution to consider. + self_loop : bool + Whether include a self-loop in every node or not. + flat : bool + Whether or not all resolution levels of the mesh are included. + neighbour_children : bool + Whether to include connections with the children from the neighbours. + depth_children : int + The number of resolution levels to consider for the connections of children. Defaults to 1, which includes + connections up to the next resolution level. + + Returns + ------- + graph : networkx.Graph + The specified graph (nodes & edges). + """ + if self_loop: + add_self_loops(graph) + + add_neighbour_edges(graph, resolutions, flat) + add_children_edges( + graph, + resolutions, + flat, + neighbour_children, + depth_children, + ) + return graph + + +def add_self_loops(graph: nx.Graph) -> None: + + for idx in graph.nodes: + add_edge(graph, idx, idx, allow_self_loop=True) + + +def add_children_edges( + graph: nx.Graph, + refinement_levels: tuple[int], + flat: bool = True, + neighbour_children: bool = False, + depth: Optional[int] = None, +) -> None: + if depth is None: + depth = len(refinement_levels) + + for ip, resolution_parent in enumerate(refinement_levels[0:-1]): + parent_cells = [node for node in graph.nodes if h3.h3_get_resolution(node) == resolution_parent] + for idx_parent in parent_cells: + # add own children + for resolution_child in refinement_levels[ip + 1 : ip + depth + 1]: + for idx_child in h3.h3_to_children(idx_parent, res=resolution_child): + if flat: + add_edge( + graph, + h3.h3_to_center_child(idx_parent, refinement_levels[-1]), + h3.h3_to_center_child(idx_child, refinement_levels[-1]), + ) + else: + add_edge(graph, idx_parent, idx_child) + + # add neighbour children + if neighbour_children: + for idx_parent_neighbour in h3.k_ring(idx_parent, k=1) & parent_cells: + for resolution_child in refinement_levels[ip + 1 : ip + depth + 1]: + for idx_child_neighbour in h3.h3_to_children(idx_parent_neighbour, res=resolution_child): + if flat: + add_edge( + graph, + h3.h3_to_center_child(idx_parent, refinement_levels[-1]), + h3.h3_to_center_child(idx_child_neighbour, refinement_levels[-1]), + ) + else: + add_edge(graph, idx_parent, idx_child_neighbour) From edd9c437b1920408cf82b69d1b86ba833f93e019 Mon Sep 17 00:00:00 2001 From: theissenhelen Date: Fri, 28 Jun 2024 16:10:01 +0000 Subject: [PATCH 083/156] feat: hexagonal builders --- src/anemoi/graphs/edges/builder.py | 37 ++++++++++++++++++++++++++++++ src/anemoi/graphs/nodes/builder.py | 7 ++++++ 2 files changed, 44 insertions(+) diff --git a/src/anemoi/graphs/edges/builder.py b/src/anemoi/graphs/edges/builder.py index 81b4709..029a86f 100644 --- a/src/anemoi/graphs/edges/builder.py +++ b/src/anemoi/graphs/edges/builder.py @@ -15,6 +15,7 @@ from anemoi.graphs import EARTH_RADIUS from anemoi.graphs.utils import get_grid_reference_distance from anemoi.graphs.nodes.builder import TriRefinedIcosahedralNodeBuilder +from anemoi.graphs.nodes.builder import HexRefinedIcosahedralNodeBuilder from anemoi.graphs.generate import icosahedral LOGGER = logging.getLogger(__name__) @@ -305,3 +306,39 @@ def get_adj_matrix(self, src_nodes: NodeStorage, dst_nodes: NodeStorage): adjmat.row = sort_func2(sort_func1(adjmat.row)) adjmat.col = sort_func2(sort_func1(adjmat.col)) return adjmat + + +class HexagonalEdgeBuilder(BaseEdgeBuilder): + """Computes hexagonal edges and adds them to a HeteroData graph.""" + + def __init__(self, src_name: str, dst_name: str, add_neighbouring_children: bool = False, depth_children: Optional[int] = 1): + super().__init__(src_name, dst_name) + self.add_neighbouring_children = add_neighbouring_children + self.depth_children = depth_children + + def transform(self, graph: HeteroData, edge_name: str, attrs_config: Optional[DotDict] = None) -> HeteroData: + assert ( + graph[self.src_name].node_type == HexRefinedIcosahedralNodeBuilder.__name__ + ), "IcosahedralConnection requires MultiScaleIcosahedral nodes." + assert graph[self.src_name] == graph[self.dst_name], "InheritConnection requires the same nodes for source and destination." + + # TODO: Next assert doesn't exist anymore since filters were moved, make sure this is checked where appropriate + # assert filter_src is None and filter_dst is None, "InheritConnection does not support filtering with attributes." + + return super().transform(graph, edge_name, attrs_config) + + def get_adj_matrix(self, src_nodes: NodeStorage, dst_nodes: NodeStorage): + + src_nodes["nx_graph"] = hexagonal.add_edges_to_nx_graph( + src_nodes["nx_graph"], + resolutions=src_nodes["resolutions"], + neighbour_children=self.add_neighbouring_children, + depth_children=self.depth_children, + ) + + adjmat = nx.to_scipy_sparse_array(src_nodes["nx_graph"], format="coo") + graph_2_sorted = dict(zip(src_nodes["node_ordering"], range(len(src_nodes.node_ordering)))) + sort_func = np.vectorize(graph_2_sorted.get) + adjmat.row = sort_func(adjmat.row) + adjmat.col = sort_func(adjmat.col) + return adjmat \ No newline at end of file diff --git a/src/anemoi/graphs/nodes/builder.py b/src/anemoi/graphs/nodes/builder.py index 9ee7c37..869ee84 100644 --- a/src/anemoi/graphs/nodes/builder.py +++ b/src/anemoi/graphs/nodes/builder.py @@ -12,6 +12,7 @@ from hydra.utils import instantiate from torch_geometric.data import HeteroData from anemoi.graphs.generate.icosahedral import create_icosahedral_nodes +from anemoi.graphs.generate.hexagonal import create_hexagonal_nodes LOGGER = logging.getLogger(__name__) @@ -243,3 +244,9 @@ def create_nodes(self) -> np.ndarray: # TODO: AOI mask builder is not used in the current implementation. return create_icosahedral_nodes(resolutions=self.resolutions) + +class HexRefinedIcosahedralNodeBuilder(RefinedIcosahedralNodeBuilder): + """It depends on the h3 Python library.""" + + def create_nodes(self) -> np.ndarray: + return create_hexagonal_nodes(self.resolutions) From b7002e375f6844697788733756eab9525cead70d Mon Sep 17 00:00:00 2001 From: theissenhelen Date: Fri, 28 Jun 2024 16:13:05 +0000 Subject: [PATCH 084/156] fix: AOI not implemented yet --- src/anemoi/graphs/generate/hexagonal.py | 12 ++---------- 1 file changed, 2 insertions(+), 10 deletions(-) diff --git a/src/anemoi/graphs/generate/hexagonal.py b/src/anemoi/graphs/generate/hexagonal.py index 902d88b..66e18f6 100644 --- a/src/anemoi/graphs/generate/hexagonal.py +++ b/src/anemoi/graphs/generate/hexagonal.py @@ -40,7 +40,6 @@ def add_edge( def get_cells_at_resolution( resolution: int, area: Optional[dict] = None, - aoi_mask_builder: Optional[KNNAreaMaskBuilder] = None, ) -> set[str]: """Get cells at a specified refinement level. @@ -61,13 +60,7 @@ def get_cells_at_resolution( # TODO: What is area? cells = h3.uncompact(h3.get_res0_indexes(), resolution) if area is None else h3.polyfill(area, resolution) - if aoi_mask_builder is not None: - cells = list(cells) - - coords = np.deg2rad(np.array([h3.h3_to_geo(c) for c in cells])) - aoi_mask = aoi_mask_builder.get_mask(coords) - - cells = set(map(str, np.array(cells)[aoi_mask])) + # TODO: AOI not used in the current implementation. return cells @@ -124,7 +117,6 @@ def create_hexagonal_nodes( resolutions: list[int], flat: bool = True, area: Optional[dict] = None, - aoi_mask_builder: Optional[KNNAreaMaskBuilder] = None, ) -> tuple[nx.Graph, torch.Tensor, list[int]]: """Creates a global mesh from a refined icosahedro. @@ -150,7 +142,7 @@ def create_hexagonal_nodes( """ graph = nx.Graph() - area_kwargs = {"area": area, "aoi_mask_builder": aoi_mask_builder} + area_kwargs = {"area": area} for resolution in resolutions: add_nodes_for_resolution(graph, resolution, **area_kwargs) From 1d55ef698958d9aa441c966e8b63a3c11d853b51 Mon Sep 17 00:00:00 2001 From: theissenhelen Date: Mon, 1 Jul 2024 11:06:54 +0000 Subject: [PATCH 085/156] fix: abstractmethod and renaming --- src/anemoi/graphs/edges/builder.py | 27 +++++++++++++++++---------- src/anemoi/graphs/nodes/builder.py | 12 +++++++----- 2 files changed, 24 insertions(+), 15 deletions(-) diff --git a/src/anemoi/graphs/edges/builder.py b/src/anemoi/graphs/edges/builder.py index 029a86f..8904fb6 100644 --- a/src/anemoi/graphs/edges/builder.py +++ b/src/anemoi/graphs/edges/builder.py @@ -3,9 +3,9 @@ from abc import abstractmethod from typing import Optional +import networkx as nx import numpy as np import torch -import networkx as nx from anemoi.utils.config import DotDict from hydra.utils import instantiate from sklearn.neighbors import NearestNeighbors @@ -13,10 +13,11 @@ from torch_geometric.data.storage import NodeStorage from anemoi.graphs import EARTH_RADIUS -from anemoi.graphs.utils import get_grid_reference_distance -from anemoi.graphs.nodes.builder import TriRefinedIcosahedralNodeBuilder -from anemoi.graphs.nodes.builder import HexRefinedIcosahedralNodeBuilder +from anemoi.graphs.generate import hexagonal from anemoi.graphs.generate import icosahedral +from anemoi.graphs.nodes.builder import HexRefinedIcosahedralNodeBuilder +from anemoi.graphs.nodes.builder import TriRefinedIcosahedralNodeBuilder +from anemoi.graphs.utils import get_grid_reference_distance LOGGER = logging.getLogger(__name__) @@ -266,7 +267,7 @@ def get_adjacency_matrix(self, source_nodes: NodeStorage, target_nodes: NodeStor return adj_matrix -class TriIcosahedralEdgeBuilder(BaseEdgeBuilder): +class TriIcosahedralEdges(BaseEdgeBuilder): """Computes icosahedral edges and adds them to a HeteroData graph.""" def __init__(self, src_name: str, dst_name: str, xhops: int): @@ -282,7 +283,9 @@ def transform(self, graph: HeteroData, edge_name: str, attrs_config: Optional[Do assert ( graph[self.src_name].node_type == TriRefinedIcosahedralNodeBuilder.__name__ ), "IcosahedralConnection requires MultiScaleIcosahedral nodes." - assert graph[self.src_name] == graph[self.dst_name], "InheritConnection requires the same nodes for source and destination." + assert ( + graph[self.src_name] == graph[self.dst_name] + ), "InheritConnection requires the same nodes for source and destination." # TODO: Next assert doesn't exist anymore since filters were moved, make sure this is checked where appropriate # assert filter_src is None and filter_dst is None, "InheritConnection does not support filtering with attributes." @@ -308,10 +311,12 @@ def get_adj_matrix(self, src_nodes: NodeStorage, dst_nodes: NodeStorage): return adjmat -class HexagonalEdgeBuilder(BaseEdgeBuilder): +class HexagonalEdges(BaseEdgeBuilder): """Computes hexagonal edges and adds them to a HeteroData graph.""" - def __init__(self, src_name: str, dst_name: str, add_neighbouring_children: bool = False, depth_children: Optional[int] = 1): + def __init__( + self, src_name: str, dst_name: str, add_neighbouring_children: bool = False, depth_children: Optional[int] = 1 + ): super().__init__(src_name, dst_name) self.add_neighbouring_children = add_neighbouring_children self.depth_children = depth_children @@ -320,7 +325,9 @@ def transform(self, graph: HeteroData, edge_name: str, attrs_config: Optional[Do assert ( graph[self.src_name].node_type == HexRefinedIcosahedralNodeBuilder.__name__ ), "IcosahedralConnection requires MultiScaleIcosahedral nodes." - assert graph[self.src_name] == graph[self.dst_name], "InheritConnection requires the same nodes for source and destination." + assert ( + graph[self.src_name] == graph[self.dst_name] + ), "InheritConnection requires the same nodes for source and destination." # TODO: Next assert doesn't exist anymore since filters were moved, make sure this is checked where appropriate # assert filter_src is None and filter_dst is None, "InheritConnection does not support filtering with attributes." @@ -341,4 +348,4 @@ def get_adj_matrix(self, src_nodes: NodeStorage, dst_nodes: NodeStorage): sort_func = np.vectorize(graph_2_sorted.get) adjmat.row = sort_func(adjmat.row) adjmat.col = sort_func(adjmat.col) - return adjmat \ No newline at end of file + return adjmat diff --git a/src/anemoi/graphs/nodes/builder.py b/src/anemoi/graphs/nodes/builder.py index 869ee84..414ecc5 100644 --- a/src/anemoi/graphs/nodes/builder.py +++ b/src/anemoi/graphs/nodes/builder.py @@ -11,8 +11,9 @@ from anemoi.utils.config import DotDict from hydra.utils import instantiate from torch_geometric.data import HeteroData -from anemoi.graphs.generate.icosahedral import create_icosahedral_nodes + from anemoi.graphs.generate.hexagonal import create_hexagonal_nodes +from anemoi.graphs.generate.icosahedral import create_icosahedral_nodes LOGGER = logging.getLogger(__name__) @@ -196,7 +197,7 @@ def get_coordinates(self) -> torch.Tensor: return coords -class RefinedIcosahedralNodeBuilder(BaseNodeBuilder): +class RefinedIcosahedralNodes(BaseNodeBuilder, ABC): """Processor mesh based on a triangular mesh. It is based on the icosahedral mesh, which is a mesh of triangles that covers the sphere. @@ -227,6 +228,7 @@ def get_coordinates(self) -> np.ndarray: self.nx_graph, coords_rad, self.node_ordering = self.create_nodes() return coords_rad[self.node_ordering] + @abstractmethod def create_nodes(self) -> np.ndarray: ... def register_attributes(self, graph: HeteroData, name: str, config: DotDict) -> HeteroData: @@ -237,15 +239,15 @@ def register_attributes(self, graph: HeteroData, name: str, config: DotDict) -> return super().register_attributes(graph, name, config) -class TriRefinedIcosahedralNodeBuilder(RefinedIcosahedralNodeBuilder): +class TriRefinedIcosahedralNodes(RefinedIcosahedralNodes): """It depends on the trimesh Python library.""" def create_nodes(self) -> np.ndarray: # TODO: AOI mask builder is not used in the current implementation. - return create_icosahedral_nodes(resolutions=self.resolutions) + return create_icosahedral_nodes(resolutions=self.resolutions) -class HexRefinedIcosahedralNodeBuilder(RefinedIcosahedralNodeBuilder): +class HexRefinedIcosahedralNodes(RefinedIcosahedralNodes): """It depends on the h3 Python library.""" def create_nodes(self) -> np.ndarray: From 7dcfb9873e5bee5614d20a6a8a70070869033967 Mon Sep 17 00:00:00 2001 From: theissenhelen Date: Mon, 1 Jul 2024 12:55:02 +0000 Subject: [PATCH 086/156] chore: add dependencies --- pyproject.toml | 3 +++ 1 file changed, 3 insertions(+) diff --git a/pyproject.toml b/pyproject.toml index cb5bb7f..6654b59 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -52,9 +52,12 @@ dynamic = [ dependencies = [ "anemoi-datasets[data]>=0.3.3", "anemoi-utils>=0.3.6", + "h3>=3.7.6", "hydra-core>=1.3", + "networkx>=3.1", "torch>=2.2", "torch-geometric>=2.3.1,<2.5", + "trimesh>=4.1", ] optional-dependencies.all = [ From a35b76fef4ff0a6ad0b7a686e0f071165897f6b8 Mon Sep 17 00:00:00 2001 From: theissenhelen Date: Mon, 1 Jul 2024 13:49:07 +0000 Subject: [PATCH 087/156] test: add tests for trimesh --- tests/nodes/test_tri_refined_icosahedral.py | 33 +++++++++++++++++++++ 1 file changed, 33 insertions(+) create mode 100644 tests/nodes/test_tri_refined_icosahedral.py diff --git a/tests/nodes/test_tri_refined_icosahedral.py b/tests/nodes/test_tri_refined_icosahedral.py new file mode 100644 index 0000000..762efdf --- /dev/null +++ b/tests/nodes/test_tri_refined_icosahedral.py @@ -0,0 +1,33 @@ +import pytest +import torch +from torch_geometric.data import HeteroData + +from anemoi.graphs.nodes import builder + + +@pytest.mark.parametrize("resolution", [0, 2]) +def test_init(resolution: list[int]): + """Test TrirefinedIcosahedralNodes initialization.""" + + node_builder = builder.TriRefinedIcosahedralNodes(resolution) + assert isinstance(node_builder, builder.BaseNodeBuilder) + assert isinstance(node_builder, builder.TriRefinedIcosahedralNodes) + + +def test_get_coordinates(): + """Test get_coordinates method.""" + node_builder = builder.TriRefinedIcosahedralNodes(2) + coords = node_builder.get_coordinates() + assert isinstance(coords, torch.Tensor) + assert coords.shape == (162, 2) + + +def test_transform(): + """Test transform method.""" + node_builder = builder.TriRefinedIcosahedralNodes(1) + graph = HeteroData() + graph = node_builder.transform(graph, "test", {}) + assert "resolutions" in graph["test"] + assert "nx_graph" in graph["test"] + assert "node_ordering" in graph["test"] + assert len(graph["test"]["node_ordering"]) == graph["test"].num_nodes From bfc3e8467d86e0443d3c51d8329e251d4be446f2 Mon Sep 17 00:00:00 2001 From: theissenhelen Date: Mon, 1 Jul 2024 14:02:05 +0000 Subject: [PATCH 088/156] test: add tests for hex (h3) --- tests/nodes/test_hex_refined_icosahedral.py | 33 +++++++++++++++++++++ 1 file changed, 33 insertions(+) create mode 100644 tests/nodes/test_hex_refined_icosahedral.py diff --git a/tests/nodes/test_hex_refined_icosahedral.py b/tests/nodes/test_hex_refined_icosahedral.py new file mode 100644 index 0000000..df0e716 --- /dev/null +++ b/tests/nodes/test_hex_refined_icosahedral.py @@ -0,0 +1,33 @@ +import pytest +import torch +from torch_geometric.data import HeteroData + +from anemoi.graphs.nodes import builder + + +@pytest.mark.parametrize("resolution", [0, 2]) +def test_init(resolution: list[int]): + """Test TrirefinedIcosahedralNodes initialization.""" + + node_builder = builder.HexRefinedIcosahedralNodes(resolution) + assert isinstance(node_builder, builder.BaseNodeBuilder) + assert isinstance(node_builder, builder.HexRefinedIcosahedralNodes) + + +def test_get_coordinates(): + """Test get_coordinates method.""" + node_builder = builder.HexRefinedIcosahedralNodes(0) + coords = node_builder.get_coordinates() + assert isinstance(coords, torch.Tensor) + assert coords.shape == (122, 2) + + +def test_transform(): + """Test transform method.""" + node_builder = builder.HexRefinedIcosahedralNodes(0) + graph = HeteroData() + graph = node_builder.transform(graph, "test", {}) + assert "resolutions" in graph["test"] + assert "nx_graph" in graph["test"] + assert "node_ordering" in graph["test"] + assert len(graph["test"]["node_ordering"]) == graph["test"].num_nodes From fc2f707f20e999b6f89f8fbdb954bc6af433351f Mon Sep 17 00:00:00 2001 From: theissenhelen Date: Mon, 1 Jul 2024 14:03:03 +0000 Subject: [PATCH 089/156] fix: imports --- src/anemoi/graphs/nodes/__init__.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/anemoi/graphs/nodes/__init__.py b/src/anemoi/graphs/nodes/__init__.py index 737f27f..ef13e41 100644 --- a/src/anemoi/graphs/nodes/__init__.py +++ b/src/anemoi/graphs/nodes/__init__.py @@ -1,4 +1,6 @@ +from .builder import HexRefinedIcosahedralNodes from .builder import NPZFileNodes +from .builder import TriRefinedIcosahedralNodes from .builder import ZarrDatasetNodes -__all__ = ["ZarrDatasetNodes", "NPZFileNodes"] +__all__ = ["ZarrDatasetNodes", "NPZFileNodes", "TriRefinedIcosahedralNodes", "HexRefinedIcosahedralNodes"] From 7c3dca37d8ba363e7745d5eb3b1466d6fd46b76c Mon Sep 17 00:00:00 2001 From: theissenhelen Date: Mon, 1 Jul 2024 14:03:44 +0000 Subject: [PATCH 090/156] fix: output type --- src/anemoi/graphs/nodes/builder.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/src/anemoi/graphs/nodes/builder.py b/src/anemoi/graphs/nodes/builder.py index 414ecc5..dbe3bfa 100644 --- a/src/anemoi/graphs/nodes/builder.py +++ b/src/anemoi/graphs/nodes/builder.py @@ -215,6 +215,7 @@ def __init__( resolution: Union[int, list[int]], np_dtype: np.dtype = np.float32, ) -> None: + # TODO: Discuss np_dtype self.np_dtype = np_dtype if isinstance(resolution, int): @@ -224,9 +225,9 @@ def __init__( super().__init__() - def get_coordinates(self) -> np.ndarray: + def get_coordinates(self) -> torch.Tensor: self.nx_graph, coords_rad, self.node_ordering = self.create_nodes() - return coords_rad[self.node_ordering] + return torch.tensor(coords_rad[self.node_ordering]) @abstractmethod def create_nodes(self) -> np.ndarray: ... @@ -251,4 +252,5 @@ class HexRefinedIcosahedralNodes(RefinedIcosahedralNodes): """It depends on the h3 Python library.""" def create_nodes(self) -> np.ndarray: + # TODO: AOI mask builder is not used in the current implementation. return create_hexagonal_nodes(self.resolutions) From d627e5878022af70fdab64a3279b5a090b5436a6 Mon Sep 17 00:00:00 2001 From: theissenhelen Date: Mon, 1 Jul 2024 14:06:19 +0000 Subject: [PATCH 091/156] refactor: delete unused file --- src/anemoi/graphs/nodes/nodes.py | 65 -------------------------------- 1 file changed, 65 deletions(-) delete mode 100644 src/anemoi/graphs/nodes/nodes.py diff --git a/src/anemoi/graphs/nodes/nodes.py b/src/anemoi/graphs/nodes/nodes.py deleted file mode 100644 index 3d59e5f..0000000 --- a/src/anemoi/graphs/nodes/nodes.py +++ /dev/null @@ -1,65 +0,0 @@ -import logging -from abc import ABC -from abc import abstractmethod -from pathlib import Path - -import numpy as np -import torch -from anemoi.datasets import open_dataset -from anemoi.utils.config import DotDict -from hydra.utils import instantiate -from torch_geometric.data import HeteroData - -logger = logging.getLogger(__name__) - - -class BaseNodeBuilder(ABC): - """Base class for node builders.""" - - def register_nodes(self, graph: HeteroData, name: str) -> None: - graph[name].x = self.get_coordinates() - graph[name].node_type = type(self).__name__ - return graph - - def register_attributes(self, graph: HeteroData, name: str, config: DotDict) -> HeteroData: - for nodes_attr_name, attr_cfg in config.items(): - graph[name][nodes_attr_name] = instantiate(attr_cfg).get_weights(graph[name]) - return graph - - @abstractmethod - def get_coordinates(self) -> np.ndarray: ... - - def reshape_coords(self, latitudes: np.ndarray, longitudes: np.ndarray) -> np.ndarray: - coords = np.stack([latitudes, longitudes], axis=-1).reshape((-1, 2)) - coords = np.deg2rad(coords) - return torch.tensor(coords, dtype=torch.float32) - - def transform(self, graph: HeteroData, name: str, attr_config: DotDict) -> HeteroData: - graph = self.register_nodes(graph, name) - graph = self.register_attributes(graph, name, attr_config) - return graph - - - -class ZarrNodes(BaseNodeBuilder): - """Nodes from Zarr dataset.""" - - def __init__(self, dataset: DotDict) -> None: - logger.info("Reading the dataset from %s.", dataset) - self.ds = open_dataset(dataset) - - def get_coordinates(self) -> torch.Tensor: - return self.reshape_coords(self.ds.latitudes, self.ds.longitudes) - - -class NPZNodes(BaseNodeBuilder): - """Nodes from NPZ defined grids.""" - - def __init__(self, resolution: str, grid_definition_path: str) -> None: - self.resolution = resolution - self.grid_definition_path = grid_definition_path - self.grid_definition = np.load(Path(self.grid_definition_path) / f"grid-{self.resolution}.npz") - - def get_coordinates(self) -> np.ndarray: - coords = self.reshape_coords(self.grid_definition["latitudes"], self.grid_definition["longitudes"]) - return coords From 735fbc1136c2e98b8075d9efd70cfefe11fe1154 Mon Sep 17 00:00:00 2001 From: theissenhelen Date: Mon, 1 Jul 2024 15:50:54 +0000 Subject: [PATCH 092/156] refactor: renaming and positioning --- src/anemoi/graphs/generate/hexagonal.py | 199 +++++++++++----------- src/anemoi/graphs/generate/icosahedral.py | 41 ++--- 2 files changed, 119 insertions(+), 121 deletions(-) diff --git a/src/anemoi/graphs/generate/hexagonal.py b/src/anemoi/graphs/generate/hexagonal.py index 66e18f6..4ab9569 100644 --- a/src/anemoi/graphs/generate/hexagonal.py +++ b/src/anemoi/graphs/generate/hexagonal.py @@ -7,62 +7,49 @@ from sklearn.metrics.pairwise import haversine_distances -def add_edge( - graph: nx.Graph, - idx1: str, - idx2: str, - allow_self_loop: bool = False, -) -> None: - """Add edge between two nodes to a graph. - - The edge will only be added in case both tail and head nodes are included in the graph, G. - - Parameters - ---------- - graph : networkx.Graph - The graph to add the nodes. - idx1 : str - The H3 index of the tail of the edge. - idx2 : str - The H3 index of the head of the edge. - allow_self_loop : bool - Whether to allow self-loops or not. Defaults to not allowing self-loops. - """ - if not graph.has_node(idx1) or not graph.has_node(idx2): - return - - if allow_self_loop or idx1 != idx2: - loc1 = np.deg2rad(h3.h3_to_geo(idx1)) - loc2 = np.deg2rad(h3.h3_to_geo(idx2)) - graph.add_edge(idx1, idx2, weight=haversine_distances([loc1, loc2])[0][1]) - - -def get_cells_at_resolution( - resolution: int, +def create_hexagonal_nodes( + resolutions: list[int], + flat: bool = True, area: Optional[dict] = None, -) -> set[str]: - """Get cells at a specified refinement level. +) -> tuple[nx.Graph, torch.Tensor, list[int]]: + """Creates a global mesh from a refined icosahedro. + + This method relies on the H3 python library, which covers the earth with hexagons (and 5 pentagons). At each + refinement level, a hexagon cell has 7 child cells (aperture 7). Parameters ---------- - resolution : int - The H3 refinement level. It can be an integer from 0 to 15. + resolutions : list[int] + Levels of mesh resolution to consider. + flat : bool + Whether or not all resolution levels of the mesh are included. area : dict - A region, in GeoJSON data format, to be contained by all cells. Defaults to None. + A region, in GeoJSON data format, to be contained by all cells. Defaults to None, which computes the global + mesh. aoi_mask_builder : KNNAreaMaskBuilder, optional - KNNAreaMaskBuilder computes nask to limit the mesh area, by default None. + KNNAreaMaskBuilder with the cloud of points to limit the mesh area, by default None. Returns ------- - cells : set[str] - The set of H3 indexes at the specified resolution level. + graph : networkx.Graph + The specified graph (nodes & edges). """ - # TODO: What is area? - cells = h3.uncompact(h3.get_res0_indexes(), resolution) if area is None else h3.polyfill(area, resolution) + graph = nx.Graph() - # TODO: AOI not used in the current implementation. + area_kwargs = {"area": area} - return cells + for resolution in resolutions: + add_nodes_for_resolution(graph, resolution, **area_kwargs) + + coords = np.array([h3.h3_to_geo(node) for node in graph.nodes]) + + # Sort nodes by latitude and longitude + node_ordering = np.lexsort(coords.T[::-1], axis=0) + + # Should these be sorted here or in the edge builder? + coords = coords[node_ordering] + + return graph, coords, node_ordering def add_nodes_for_resolution( @@ -84,78 +71,42 @@ def add_nodes_for_resolution( area_kwargs: dict Additional arguments to pass to the get_cells_at_resolution function. """ - for idx in get_cells_at_resolution(resolution, **area_kwargs): + + cells = get_cells_at_resolution(resolution, **area_kwargs) + + for idx in cells: graph.add_node(idx, hcoords_rad=np.deg2rad(h3.h3_to_geo(idx))) if self_loop: # TODO: should that be add_self_loops(graph)? add_edge(graph, idx, idx, allow_self_loop=self_loop) -def add_neighbour_edges( - graph: nx.Graph, - refinement_levels: tuple[int], - flat: bool = True, -) -> None: - for resolution in refinement_levels: - cells = {node for node in graph.nodes if h3.h3_get_resolution(node) == resolution} - for idx in cells: - k = 2 if resolution == 0 else 1 # refinement_levels[0]: # extra large field of vision ; only few nodes - - # neighbours - for idx_neighbour in h3.k_ring(idx, k=k) & cells: - if flat: - add_edge( - graph, - h3.h3_to_center_child(idx, refinement_levels[-1]), - h3.h3_to_center_child(idx_neighbour, refinement_levels[-1]), - ) - else: - add_edge(graph, idx, idx_neighbour) - - -def create_hexagonal_nodes( - resolutions: list[int], - flat: bool = True, +def get_cells_at_resolution( + resolution: int, area: Optional[dict] = None, -) -> tuple[nx.Graph, torch.Tensor, list[int]]: - """Creates a global mesh from a refined icosahedro. - - This method relies on the H3 python library, which covers the earth with hexagons (and 5 pentagons). At each - refinement level, a hexagon cell has 7 child cells (aperture 7). +) -> set[str]: + """Get cells at a specified refinement level. Parameters ---------- - resolutions : list[int] - Levels of mesh resolution to consider. - flat : bool - Whether or not all resolution levels of the mesh are included. + resolution : int + The H3 refinement level. It can be an integer from 0 to 15. area : dict - A region, in GeoJSON data format, to be contained by all cells. Defaults to None, which computes the global - mesh. + A region, in GeoJSON data format, to be contained by all cells. Defaults to None. aoi_mask_builder : KNNAreaMaskBuilder, optional - KNNAreaMaskBuilder with the cloud of points to limit the mesh area, by default None. + KNNAreaMaskBuilder computes nask to limit the mesh area, by default None. Returns ------- - graph : networkx.Graph - The specified graph (nodes & edges). + cells : set[str] + The set of H3 indexes at the specified resolution level. """ - graph = nx.Graph() - - area_kwargs = {"area": area} - - for resolution in resolutions: - add_nodes_for_resolution(graph, resolution, **area_kwargs) - - coords = np.array([h3.h3_to_geo(node) for node in graph.nodes]) - - # Sort nodes by latitude and longitude - node_ordering = np.lexsort(coords.T[::-1], axis=0) + # TODO: What is area? + cells = h3.uncompact(h3.get_res0_indexes(), resolution) if area is None else h3.polyfill(area, resolution) - # Should these be sorted here or in the edge builder? - coords = coords[node_ordering] + # TODO: AOI not used in the current implementation. - return graph, coords, node_ordering + return cells def add_edges_to_nx_graph( @@ -212,6 +163,28 @@ def add_self_loops(graph: nx.Graph) -> None: add_edge(graph, idx, idx, allow_self_loop=True) +def add_neighbour_edges( + graph: nx.Graph, + refinement_levels: tuple[int], + flat: bool = True, +) -> None: + for resolution in refinement_levels: + cells = {node for node in graph.nodes if h3.h3_get_resolution(node) == resolution} + for idx in cells: + k = 2 if resolution == 0 else 1 # refinement_levels[0]: # extra large field of vision ; only few nodes + + # neighbours + for idx_neighbour in h3.k_ring(idx, k=k) & cells: + if flat: + add_edge( + graph, + h3.h3_to_center_child(idx, refinement_levels[-1]), + h3.h3_to_center_child(idx_neighbour, refinement_levels[-1]), + ) + else: + add_edge(graph, idx, idx_neighbour) + + def add_children_edges( graph: nx.Graph, refinement_levels: tuple[int], @@ -250,3 +223,33 @@ def add_children_edges( ) else: add_edge(graph, idx_parent, idx_child_neighbour) + + +def add_edge( + graph: nx.Graph, + idx1: str, + idx2: str, + allow_self_loop: bool = False, +) -> None: + """Add edge between two nodes to a graph. + + The edge will only be added in case both tail and head nodes are included in the graph, G. + + Parameters + ---------- + graph : networkx.Graph + The graph to add the nodes. + idx1 : str + The H3 index of the tail of the edge. + idx2 : str + The H3 index of the head of the edge. + allow_self_loop : bool + Whether to allow self-loops or not. Defaults to not allowing self-loops. + """ + if not graph.has_node(idx1) or not graph.has_node(idx2): + return + + if allow_self_loop or idx1 != idx2: + loc1 = np.deg2rad(h3.h3_to_geo(idx1)) + loc2 = np.deg2rad(h3.h3_to_geo(idx2)) + graph.add_edge(idx1, idx2, weight=haversine_distances([loc1, loc2])[0][1]) diff --git a/src/anemoi/graphs/generate/icosahedral.py b/src/anemoi/graphs/generate/icosahedral.py index 97c4801..7f124cc 100644 --- a/src/anemoi/graphs/generate/icosahedral.py +++ b/src/anemoi/graphs/generate/icosahedral.py @@ -1,3 +1,4 @@ +import logging from collections.abc import Iterable from typing import Optional @@ -8,7 +9,6 @@ from sklearn.neighbors import BallTree from anemoi.graphs.generate.transforms import cartesian_to_latlon_rad -import logging logger = logging.getLogger(__name__) @@ -37,7 +37,7 @@ def create_icosahedral_nodes( Order of the nodes in the graph to be sorted by latitude and longitude. """ sphere = trimesh.creation.icosphere(subdivisions=resolutions[-1], radius=1.0) - + coords_rad = cartesian_to_latlon_rad(sphere.vertices) node_ordering = get_node_ordering(coords_rad) @@ -61,11 +61,11 @@ def create_icosahedral_nx_graph_from_coords(coords_rad: np.ndarray, node_orderin return graph -def get_node_ordering(vertices_rad: np.ndarray) -> np.ndarray: +def get_node_ordering(coords_rad: np.ndarray) -> np.ndarray: # Get indices to sort points by lon & lat in radians. - ind1 = np.argsort(vertices_rad[:, 1]) - ind2 = np.argsort(vertices_rad[ind1][:, 0])[::-1] - node_ordering = np.arange(vertices_rad.shape[0])[ind1][ind2] + ind1 = np.argsort(coords_rad[:, 1]) + ind2 = np.argsort(coords_rad[ind1][:, 0])[::-1] + node_ordering = np.arange(coords_rad.shape[0])[ind1][ind2] return node_ordering @@ -108,10 +108,10 @@ def add_edges_to_nx_graph( # TODO AOI mask builder is not used in the current implementation. valid_nodes = None - x_rings = get_x_hops(r_sphere, xhops, valid_nodes=valid_nodes) + x_hops = get_x_hops(r_sphere, xhops, valid_nodes=valid_nodes) _, idx = tree.query(r_vertices_rad, k=1) - for i, i_neighbours in x_rings.items(): + for i, i_neighbours in x_hops.items(): add_neigbours_edges(graph, r_vertices_rad, i, i_neighbours, idx=idx) return graph @@ -151,8 +151,8 @@ def get_x_hops(tri_mesh: trimesh.Trimesh, hops: int, valid_nodes: Optional[list[ def add_neigbours_edges( graph: nx.Graph, vertices: np.ndarray, - ii: int, - neighbours: Iterable[int], + node_idx: int, + neighbour_indices: Iterable[int], self_loops: bool = False, idx: Optional[np.ndarray] = None, ) -> None: @@ -164,7 +164,7 @@ def add_neigbours_edges( The graph. vertices : np.ndarray A 2D array of shape (num_vertices, 2) with the planar coordinates of the mesh, in radians. - ii : int + node_idx : int The node considered. neighbours : list[int] The neighbours of the node. @@ -173,26 +173,21 @@ def add_neigbours_edges( idx : np.ndarray, optional Index to map the vertices from the refined sphere to the original one, by default None. """ - for idx_neighbour in neighbours: - if not self_loops and ii == idx_neighbour: # no self-loops + for neighbour_idx in neighbour_indices: + if not self_loops and node_idx == neighbour_idx: # no self-loops continue - location_node = vertices[ii] - location_neighbour = vertices[idx_neighbour] + location_node = vertices[node_idx] + location_neighbour = vertices[neighbour_idx] edge_length = haversine_distances([location_neighbour, location_node])[0][1] if idx is not None: # Use the same method to add edge in all spheres - node_neighbour = idx[idx_neighbour][0] - node = idx[ii][0] + node_neighbour = idx[neighbour_idx][0] + node = idx[node_idx][0] else: - node, node_neighbour = ii, idx_neighbour + node, node_neighbour = node_idx, neighbour_idx # add edge to the graph if node in graph and node_neighbour in graph: graph.add_edge(node_neighbour, node, weight=edge_length) - - - - - From 46a2c07aa84b37c754d40ba9fc5578daea8aabc0 Mon Sep 17 00:00:00 2001 From: theissenhelen Date: Mon, 1 Jul 2024 16:49:46 +0000 Subject: [PATCH 093/156] feat: ensure src and dst always the same --- src/anemoi/graphs/edges/builder.py | 28 ++++++++++------------------ 1 file changed, 10 insertions(+), 18 deletions(-) diff --git a/src/anemoi/graphs/edges/builder.py b/src/anemoi/graphs/edges/builder.py index 8904fb6..60da96b 100644 --- a/src/anemoi/graphs/edges/builder.py +++ b/src/anemoi/graphs/edges/builder.py @@ -15,8 +15,8 @@ from anemoi.graphs import EARTH_RADIUS from anemoi.graphs.generate import hexagonal from anemoi.graphs.generate import icosahedral -from anemoi.graphs.nodes.builder import HexRefinedIcosahedralNodeBuilder -from anemoi.graphs.nodes.builder import TriRefinedIcosahedralNodeBuilder +from anemoi.graphs.nodes.builder import HexRefinedIcosahedralNodes +from anemoi.graphs.nodes.builder import TriRefinedIcosahedralNodes from anemoi.graphs.utils import get_grid_reference_distance LOGGER = logging.getLogger(__name__) @@ -270,8 +270,8 @@ def get_adjacency_matrix(self, source_nodes: NodeStorage, target_nodes: NodeStor class TriIcosahedralEdges(BaseEdgeBuilder): """Computes icosahedral edges and adds them to a HeteroData graph.""" - def __init__(self, src_name: str, dst_name: str, xhops: int): - super().__init__(src_name, dst_name) + def __init__(self, src_name: str, xhops: int): + super().__init__(src_name, src_name) assert isinstance(xhops, int), "Number of xhops must be an integer" assert xhops > 0, "Number of xhops must be positive" @@ -281,11 +281,8 @@ def __init__(self, src_name: str, dst_name: str, xhops: int): def transform(self, graph: HeteroData, edge_name: str, attrs_config: Optional[DotDict] = None) -> HeteroData: assert ( - graph[self.src_name].node_type == TriRefinedIcosahedralNodeBuilder.__name__ - ), "IcosahedralConnection requires MultiScaleIcosahedral nodes." - assert ( - graph[self.src_name] == graph[self.dst_name] - ), "InheritConnection requires the same nodes for source and destination." + graph[self.src_name].node_type == TriRefinedIcosahedralNodes.__name__ + ), "IcosahedralConnection requires TriRefinedIcosahedralNodes." # TODO: Next assert doesn't exist anymore since filters were moved, make sure this is checked where appropriate # assert filter_src is None and filter_dst is None, "InheritConnection does not support filtering with attributes." @@ -314,20 +311,15 @@ def get_adj_matrix(self, src_nodes: NodeStorage, dst_nodes: NodeStorage): class HexagonalEdges(BaseEdgeBuilder): """Computes hexagonal edges and adds them to a HeteroData graph.""" - def __init__( - self, src_name: str, dst_name: str, add_neighbouring_children: bool = False, depth_children: Optional[int] = 1 - ): - super().__init__(src_name, dst_name) + def __init__(self, src_name: str, add_neighbouring_children: bool = False, depth_children: Optional[int] = 1): + super().__init__(src_name, src_name) self.add_neighbouring_children = add_neighbouring_children self.depth_children = depth_children def transform(self, graph: HeteroData, edge_name: str, attrs_config: Optional[DotDict] = None) -> HeteroData: assert ( - graph[self.src_name].node_type == HexRefinedIcosahedralNodeBuilder.__name__ - ), "IcosahedralConnection requires MultiScaleIcosahedral nodes." - assert ( - graph[self.src_name] == graph[self.dst_name] - ), "InheritConnection requires the same nodes for source and destination." + graph[self.src_name].node_type == HexRefinedIcosahedralNodes.__name__ + ), "HexagonalEdges requires HexRefinedIcosahedralNodes." # TODO: Next assert doesn't exist anymore since filters were moved, make sure this is checked where appropriate # assert filter_src is None and filter_dst is None, "InheritConnection does not support filtering with attributes." From c9e4fdeb7e3f2f6b96a6fcbd2131726adb73a224 Mon Sep 17 00:00:00 2001 From: theissenhelen Date: Mon, 1 Jul 2024 16:50:01 +0000 Subject: [PATCH 094/156] fix: imports --- src/anemoi/graphs/edges/__init__.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/anemoi/graphs/edges/__init__.py b/src/anemoi/graphs/edges/__init__.py index 53b9c74..e95e3cd 100644 --- a/src/anemoi/graphs/edges/__init__.py +++ b/src/anemoi/graphs/edges/__init__.py @@ -1,4 +1,6 @@ from .builder import CutOffEdges +from .builder import HexagonalEdges from .builder import KNNEdges +from .builder import TriIcosahedralEdges -__all__ = ["KNNEdges", "CutOffEdges"] +__all__ = ["KNNEdges", "CutOffEdges", "TriIcosahedralEdges", "HexagonalEdges"] From 7a8b316548b08e18de23afcdf359d913c27b07f2 Mon Sep 17 00:00:00 2001 From: theissenhelen Date: Mon, 1 Jul 2024 17:22:41 +0000 Subject: [PATCH 095/156] fix: edge_name not supported --- src/anemoi/graphs/edges/builder.py | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/src/anemoi/graphs/edges/builder.py b/src/anemoi/graphs/edges/builder.py index 60da96b..f1514e0 100644 --- a/src/anemoi/graphs/edges/builder.py +++ b/src/anemoi/graphs/edges/builder.py @@ -278,7 +278,7 @@ def __init__(self, src_name: str, xhops: int): self.xhops = xhops - def transform(self, graph: HeteroData, edge_name: str, attrs_config: Optional[DotDict] = None) -> HeteroData: + def transform(self, graph: HeteroData, attrs_config: Optional[DotDict] = None) -> HeteroData: assert ( graph[self.src_name].node_type == TriRefinedIcosahedralNodes.__name__ @@ -287,7 +287,7 @@ def transform(self, graph: HeteroData, edge_name: str, attrs_config: Optional[Do # TODO: Next assert doesn't exist anymore since filters were moved, make sure this is checked where appropriate # assert filter_src is None and filter_dst is None, "InheritConnection does not support filtering with attributes." - return super().transform(graph, edge_name, attrs_config) + return super().transform(graph, attrs_config) def get_adj_matrix(self, src_nodes: NodeStorage, dst_nodes: NodeStorage): @@ -295,7 +295,6 @@ def get_adj_matrix(self, src_nodes: NodeStorage, dst_nodes: NodeStorage): src_nodes["nx_graph"], resolutions=src_nodes["resolutions"], xhops=self.xhops, - aoi_nneighb=None if "aoi_nneighb" not in src_nodes else src_nodes["aoi_nneigh"], ) # HeteroData refuses to accept None adjmat = nx.to_scipy_sparse_array(src_nodes["nx_graph"], nodelist=list(src_nodes["nx_graph"]), format="coo") @@ -316,7 +315,7 @@ def __init__(self, src_name: str, add_neighbouring_children: bool = False, depth self.add_neighbouring_children = add_neighbouring_children self.depth_children = depth_children - def transform(self, graph: HeteroData, edge_name: str, attrs_config: Optional[DotDict] = None) -> HeteroData: + def transform(self, graph: HeteroData, attrs_config: Optional[DotDict] = None) -> HeteroData: assert ( graph[self.src_name].node_type == HexRefinedIcosahedralNodes.__name__ ), "HexagonalEdges requires HexRefinedIcosahedralNodes." @@ -324,7 +323,7 @@ def transform(self, graph: HeteroData, edge_name: str, attrs_config: Optional[Do # TODO: Next assert doesn't exist anymore since filters were moved, make sure this is checked where appropriate # assert filter_src is None and filter_dst is None, "InheritConnection does not support filtering with attributes." - return super().transform(graph, edge_name, attrs_config) + return super().transform(graph, attrs_config) def get_adj_matrix(self, src_nodes: NodeStorage, dst_nodes: NodeStorage): From 02496ec60f09a283a3a7c515729242ee2b5de18f Mon Sep 17 00:00:00 2001 From: theissenhelen Date: Mon, 1 Jul 2024 17:23:25 +0000 Subject: [PATCH 096/156] test: add tests for TriIcosahedralEdges --- tests/edges/test_tri_icosahedral_edges.py | 42 +++++++++++++++++++++++ 1 file changed, 42 insertions(+) create mode 100644 tests/edges/test_tri_icosahedral_edges.py diff --git a/tests/edges/test_tri_icosahedral_edges.py b/tests/edges/test_tri_icosahedral_edges.py new file mode 100644 index 0000000..9663cb0 --- /dev/null +++ b/tests/edges/test_tri_icosahedral_edges.py @@ -0,0 +1,42 @@ +import pytest +from torch_geometric.data import HeteroData + +from anemoi.graphs.edges import TriIcosahedralEdges +from anemoi.graphs.nodes import TriRefinedIcosahedralNodes + + +class TestTriIcosahedralEdgesInit: + def test_init(self): + """Test TriIcosahedralEdges initialization.""" + assert isinstance(TriIcosahedralEdges("test_nodes", 1), TriIcosahedralEdges) + + @pytest.mark.parametrize("xhops", [-0.5, "hello", None, -4]) + def test_fail_init(self, xhops: str): + """Test TriIcosahedralEdges initialization with invalid cutoff.""" + with pytest.raises(AssertionError): + TriIcosahedralEdges("test_nodes", xhops) + + +class TestTriIcosahedralEdgesTransform: + + @pytest.fixture() + def ico_graph(self) -> HeteroData: + """Return a HeteroData object with TriRefinedIcosahedralNodes.""" + graph = HeteroData() + graph = TriRefinedIcosahedralNodes(0).transform(graph, "test_nodes", {}) + graph["fail_nodes"].x = [1, 2, 3] + graph["fail_nodes"].node_type = "FailNodes" + return graph + + def test_transform_same_src_dst_nodes(self, ico_graph: HeteroData): + """Test TriIcosahedralEdges transform method.""" + + tri_icosahedral_edges = TriIcosahedralEdges("test_nodes", 1) + graph = tri_icosahedral_edges.transform(ico_graph) + assert ("test_nodes", "to", "test_nodes") in graph.edge_types + + def test_transform_fail_nodes(self, ico_graph: HeteroData): + """Test TriIcosahedralEdges transform method with wrong node type.""" + tri_icosahedral_edges = TriIcosahedralEdges("fail_nodes", 1) + with pytest.raises(AssertionError): + tri_icosahedral_edges.transform(ico_graph) From 59fcad34916093dcd75d9c59d09fd61fff1e69ef Mon Sep 17 00:00:00 2001 From: theissenhelen Date: Tue, 2 Jul 2024 08:45:13 +0000 Subject: [PATCH 097/156] fix: assert missing for Hexagonal edges --- src/anemoi/graphs/edges/builder.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/src/anemoi/graphs/edges/builder.py b/src/anemoi/graphs/edges/builder.py index f1514e0..290cafc 100644 --- a/src/anemoi/graphs/edges/builder.py +++ b/src/anemoi/graphs/edges/builder.py @@ -313,6 +313,9 @@ class HexagonalEdges(BaseEdgeBuilder): def __init__(self, src_name: str, add_neighbouring_children: bool = False, depth_children: Optional[int] = 1): super().__init__(src_name, src_name) self.add_neighbouring_children = add_neighbouring_children + + assert isinstance(depth_children, int), "Depth of children must be an integer" + assert depth_children > 0, "Depth of children must be positive" self.depth_children = depth_children def transform(self, graph: HeteroData, attrs_config: Optional[DotDict] = None) -> HeteroData: From b41f88ecf1a8e819a6942565fdd4e1d5533f559f Mon Sep 17 00:00:00 2001 From: theissenhelen Date: Tue, 2 Jul 2024 08:46:08 +0000 Subject: [PATCH 098/156] test: hexagonal edges --- tests/edges/test_hex_refined_icosahedral.py | 42 +++++++++++++++++++++ 1 file changed, 42 insertions(+) create mode 100644 tests/edges/test_hex_refined_icosahedral.py diff --git a/tests/edges/test_hex_refined_icosahedral.py b/tests/edges/test_hex_refined_icosahedral.py new file mode 100644 index 0000000..d6bb29c --- /dev/null +++ b/tests/edges/test_hex_refined_icosahedral.py @@ -0,0 +1,42 @@ +import pytest +from torch_geometric.data import HeteroData + +from anemoi.graphs.edges import HexagonalEdges +from anemoi.graphs.nodes import HexRefinedIcosahedralNodes + + +class TestTriIcosahedralEdgesInit: + def test_init(self): + """Test TriIcosahedralEdges initialization.""" + assert isinstance(HexagonalEdges("test_nodes"), HexagonalEdges) + + @pytest.mark.parametrize("depth_children", [-0.5, "hello", None, -4]) + def test_fail_init(self, depth_children: str): + """Test HexagonalEdges initialization with invalid cutoff.""" + with pytest.raises(AssertionError): + HexagonalEdges("test_nodes", True, depth_children) + + +class TestTriIcosahedralEdgesTransform: + + @pytest.fixture() + def ico_graph(self) -> HeteroData: + """Return a HeteroData object with HexRefinedIcosahedralNodes.""" + graph = HeteroData() + graph = HexRefinedIcosahedralNodes(0).transform(graph, "test_nodes", {}) + graph["fail_nodes"].x = [1, 2, 3] + graph["fail_nodes"].node_type = "FailNodes" + return graph + + def test_transform_same_src_dst_nodes(self, ico_graph: HeteroData): + """Test HexagonalEdges transform method.""" + + tri_icosahedral_edges = HexagonalEdges("test_nodes") + graph = tri_icosahedral_edges.transform(ico_graph) + assert ("test_nodes", "to", "test_nodes") in graph.edge_types + + def test_transform_fail_nodes(self, ico_graph: HeteroData): + """Test HexagonalEdges transform method with wrong node type.""" + tri_icosahedral_edges = HexagonalEdges("fail_nodes") + with pytest.raises(AssertionError): + tri_icosahedral_edges.transform(ico_graph) From 5a431854b321ecbbf24bb6505e214c18eb1544d3 Mon Sep 17 00:00:00 2001 From: theissenhelen Date: Tue, 2 Jul 2024 09:49:12 +0000 Subject: [PATCH 099/156] fix: avoid same name --- .../{test_hex_refined_icosahedral.py => test_hexagonal_edges.py} | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename tests/edges/{test_hex_refined_icosahedral.py => test_hexagonal_edges.py} (100%) diff --git a/tests/edges/test_hex_refined_icosahedral.py b/tests/edges/test_hexagonal_edges.py similarity index 100% rename from tests/edges/test_hex_refined_icosahedral.py rename to tests/edges/test_hexagonal_edges.py From a0259f8a68983bbabe7089a7b31e8c1bfb218dd6 Mon Sep 17 00:00:00 2001 From: theissenhelen Date: Mon, 8 Jul 2024 14:40:28 +0000 Subject: [PATCH 100/156] fix: imports --- src/anemoi/graphs/edges/attributes.py | 28 ++++++++++++++++++--------- 1 file changed, 19 insertions(+), 9 deletions(-) diff --git a/src/anemoi/graphs/edges/attributes.py b/src/anemoi/graphs/edges/attributes.py index f54510b..371ea82 100644 --- a/src/anemoi/graphs/edges/attributes.py +++ b/src/anemoi/graphs/edges/attributes.py @@ -6,6 +6,8 @@ import numpy as np import torch +from anemoi.utils.config import DotDict +from hydra.utils import instantiate from torch_geometric.data import HeteroData from anemoi.graphs.edges.directional import directional_edge_features @@ -14,9 +16,10 @@ logger = logging.getLogger(__name__) -class NodeAttributeBuilder(): - def transform(self, graph: HeteroData, graph_config: DotDict): +class NodeAttributeBuilder: + + def transform(self, graph: HeteroData, graph_config: DotDict): for name, nodes_cfg in graph_config.nodes.items(): graph = self.register_node_attributes(graph, name, nodes_cfg.get("attributes", {})) @@ -24,14 +27,17 @@ def transform(self, graph: HeteroData, graph_config: DotDict): def register_node_attributes(self, graph: HeteroData, node_name: str, node_config: DotDict): assert node_name in graph.keys(), f"Node {node_name} does not exist in the graph." for attr_name, attr_cfg in node_config.items(): - graph[node_name][attr_name] = instantiate(attr_cfg).compute(graph, node_name) + graph[node_name][attr_name] = instantiate(attr_cfg).compute(graph, node_name) return graph -class EdgeAttributeBuilder(): + +class EdgeAttributeBuilder: def transform(self, graph: HeteroData, graph_config: DotDict): for edges_cfg in graph_config.edges: - graph = self.register_edge_attributes(graph, edges_cfg.nodes.src_name, edges_cfg.nodes.dst_name, edges_cfg.get("attributes", {})) + graph = self.register_edge_attributes( + graph, edges_cfg.nodes.src_name, edges_cfg.nodes.dst_name, edges_cfg.get("attributes", {}) + ) return graph def register_edge_attributes(self, graph: HeteroData, src_name: str, dst_name: str, edge_config: DotDict): @@ -40,12 +46,16 @@ def register_edge_attributes(self, graph: HeteroData, src_name: str, dst_name: s attr_values = instantiate(attr_cfg).compute(graph, src_name, dst_name) graph = self.register_edge_attribute(graph, src_name, dst_name, attr_name, attr_values) return graph - - def register_edge_attribute(self, graph: HeteroData, src_name: str, dst_name: str, attr_name: str, attr_values: torch.Tensor): + + def register_edge_attribute( + self, graph: HeteroData, src_name: str, dst_name: str, attr_name: str, attr_values: torch.Tensor + ): num_edges = graph[(src_name, "to", dst_name)].num_edges - assert ( attr_values.shape[0] == num_edges), f"Number of edge features ({attr_values.shape[0]}) must match number of edges ({num_edges})." + assert ( + attr_values.shape[0] == num_edges + ), f"Number of edge features ({attr_values.shape[0]}) must match number of edges ({num_edges})." - graph[(src_name, "to", dst_name)][attr_name] = attr_values + graph[(src_name, "to", dst_name)][attr_name] = attr_values return graph From 7289e3290f638fed0cbd92c966202b70cbc95b13 Mon Sep 17 00:00:00 2001 From: theissenhelen Date: Mon, 8 Jul 2024 15:21:25 +0000 Subject: [PATCH 101/156] fix: conflicts --- src/anemoi/graphs/nodes/builder.py | 3 --- src/anemoi/graphs/normalizer.py | 2 +- 2 files changed, 1 insertion(+), 4 deletions(-) diff --git a/src/anemoi/graphs/nodes/builder.py b/src/anemoi/graphs/nodes/builder.py index dbe3bfa..27f374a 100644 --- a/src/anemoi/graphs/nodes/builder.py +++ b/src/anemoi/graphs/nodes/builder.py @@ -24,9 +24,6 @@ class BaseNodeBuilder(ABC): The node coordinates are stored in the `x` attribute of the nodes and they are stored in radians. """ - def __init__(self, name: str) -> None: - self.name = name - def __init__(self) -> None: self.aoi_mask_builder = None diff --git a/src/anemoi/graphs/normalizer.py b/src/anemoi/graphs/normalizer.py index bdadfab..c625417 100644 --- a/src/anemoi/graphs/normalizer.py +++ b/src/anemoi/graphs/normalizer.py @@ -2,7 +2,7 @@ import numpy as np -logger = logging.getLogger(__name__) +LOGGER = logging.getLogger(__name__) class NormalizerMixin: From 4ca717bbe5c9423a6ccc4cae0f7a1515bedcea57 Mon Sep 17 00:00:00 2001 From: Mario Santa Cruz Date: Tue, 9 Jul 2024 08:58:11 +0000 Subject: [PATCH 102/156] update tests --- src/anemoi/graphs/create.py | 1 - src/anemoi/graphs/edges/builder.py | 60 ++++++++++++--------- src/anemoi/graphs/nodes/builder.py | 19 +++---- tests/edges/test_attributes.py | 20 ------- tests/edges/test_hexagonal_edges.py | 19 ++++--- tests/edges/test_tri_icosahedral_edges.py | 25 +++++---- tests/nodes/test_hex_refined_icosahedral.py | 27 +++++----- tests/nodes/test_tri_refined_icosahedral.py | 27 +++++----- 8 files changed, 98 insertions(+), 100 deletions(-) delete mode 100644 tests/edges/test_attributes.py diff --git a/src/anemoi/graphs/create.py b/src/anemoi/graphs/create.py index c4f1030..bf3cb0c 100644 --- a/src/anemoi/graphs/create.py +++ b/src/anemoi/graphs/create.py @@ -26,7 +26,6 @@ def __init__( else: self.config = config - self.path = path # Output path self.cache = cache self.print = print diff --git a/src/anemoi/graphs/edges/builder.py b/src/anemoi/graphs/edges/builder.py index 290cafc..d27c95d 100644 --- a/src/anemoi/graphs/edges/builder.py +++ b/src/anemoi/graphs/edges/builder.py @@ -270,36 +270,39 @@ def get_adjacency_matrix(self, source_nodes: NodeStorage, target_nodes: NodeStor class TriIcosahedralEdges(BaseEdgeBuilder): """Computes icosahedral edges and adds them to a HeteroData graph.""" - def __init__(self, src_name: str, xhops: int): - super().__init__(src_name, src_name) + def __init__(self, source_name: str, target_name: str, xhops: int): + super().__init__(source_name, target_name) + assert source_name == target_name, "TriIcosahedralEdges requires source and target nodes to be the same." assert isinstance(xhops, int), "Number of xhops must be an integer" assert xhops > 0, "Number of xhops must be positive" self.xhops = xhops - def transform(self, graph: HeteroData, attrs_config: Optional[DotDict] = None) -> HeteroData: + def update_graph(self, graph: HeteroData, attrs_config: Optional[DotDict] = None) -> HeteroData: assert ( - graph[self.src_name].node_type == TriRefinedIcosahedralNodes.__name__ - ), "IcosahedralConnection requires TriRefinedIcosahedralNodes." + graph[self.source_name].node_type == TriRefinedIcosahedralNodes.__name__ + ), f"{self.__class__.__name__} requires TriRefinedIcosahedralNodes." # TODO: Next assert doesn't exist anymore since filters were moved, make sure this is checked where appropriate # assert filter_src is None and filter_dst is None, "InheritConnection does not support filtering with attributes." - return super().transform(graph, attrs_config) + return super().update_graph(graph, attrs_config) - def get_adj_matrix(self, src_nodes: NodeStorage, dst_nodes: NodeStorage): + def get_adjacency_matrix(self, source_nodes: NodeStorage, target_nodes: NodeStorage): - src_nodes["nx_graph"] = icosahedral.add_edges_to_nx_graph( - src_nodes["nx_graph"], - resolutions=src_nodes["resolutions"], + source_nodes["nx_graph"] = icosahedral.add_edges_to_nx_graph( + source_nodes["nx_graph"], + resolutions=source_nodes["resolutions"], xhops=self.xhops, ) # HeteroData refuses to accept None - adjmat = nx.to_scipy_sparse_array(src_nodes["nx_graph"], nodelist=list(src_nodes["nx_graph"]), format="coo") - graph_1_sorted = dict(zip(range(len(src_nodes["nx_graph"].nodes)), list(src_nodes["nx_graph"].nodes))) - graph_2_sorted = dict(zip(src_nodes.node_ordering, range(len(src_nodes.node_ordering)))) + adjmat = nx.to_scipy_sparse_array( + source_nodes["nx_graph"], nodelist=list(source_nodes["nx_graph"]), format="coo" + ) + graph_1_sorted = dict(zip(range(len(source_nodes["nx_graph"].nodes)), list(source_nodes["nx_graph"].nodes))) + graph_2_sorted = dict(zip(source_nodes.node_ordering, range(len(source_nodes.node_ordering)))) sort_func1 = np.vectorize(graph_1_sorted.get) sort_func2 = np.vectorize(graph_2_sorted.get) adjmat.row = sort_func2(sort_func1(adjmat.row)) @@ -310,35 +313,42 @@ def get_adj_matrix(self, src_nodes: NodeStorage, dst_nodes: NodeStorage): class HexagonalEdges(BaseEdgeBuilder): """Computes hexagonal edges and adds them to a HeteroData graph.""" - def __init__(self, src_name: str, add_neighbouring_children: bool = False, depth_children: Optional[int] = 1): - super().__init__(src_name, src_name) + def __init__( + self, + source_name: str, + target_name: str, + add_neighbouring_children: bool = False, + depth_children: Optional[int] = 1, + ): + super().__init__(source_name, source_name) self.add_neighbouring_children = add_neighbouring_children + assert source_name == target_name, "TriIcosahedralEdges requires source and target nodes to be the same." assert isinstance(depth_children, int), "Depth of children must be an integer" assert depth_children > 0, "Depth of children must be positive" self.depth_children = depth_children - def transform(self, graph: HeteroData, attrs_config: Optional[DotDict] = None) -> HeteroData: + def update_graph(self, graph: HeteroData, attrs_config: Optional[DotDict] = None) -> HeteroData: assert ( - graph[self.src_name].node_type == HexRefinedIcosahedralNodes.__name__ - ), "HexagonalEdges requires HexRefinedIcosahedralNodes." + graph[self.source_name].node_type == HexRefinedIcosahedralNodes.__name__ + ), f"{self.__class__.__name__} requires HexRefinedIcosahedralNodes." # TODO: Next assert doesn't exist anymore since filters were moved, make sure this is checked where appropriate # assert filter_src is None and filter_dst is None, "InheritConnection does not support filtering with attributes." - return super().transform(graph, attrs_config) + return super().update_graph(graph, attrs_config) - def get_adj_matrix(self, src_nodes: NodeStorage, dst_nodes: NodeStorage): + def get_adjacency_matrix(self, source_nodes: NodeStorage, target_nodes: NodeStorage): - src_nodes["nx_graph"] = hexagonal.add_edges_to_nx_graph( - src_nodes["nx_graph"], - resolutions=src_nodes["resolutions"], + source_nodes["nx_graph"] = hexagonal.add_edges_to_nx_graph( + source_nodes["nx_graph"], + resolutions=source_nodes["resolutions"], neighbour_children=self.add_neighbouring_children, depth_children=self.depth_children, ) - adjmat = nx.to_scipy_sparse_array(src_nodes["nx_graph"], format="coo") - graph_2_sorted = dict(zip(src_nodes["node_ordering"], range(len(src_nodes.node_ordering)))) + adjmat = nx.to_scipy_sparse_array(source_nodes["nx_graph"], format="coo") + graph_2_sorted = dict(zip(source_nodes["node_ordering"], range(len(source_nodes.node_ordering)))) sort_func = np.vectorize(graph_2_sorted.get) adjmat.row = sort_func(adjmat.row) adjmat.col = sort_func(adjmat.col) diff --git a/src/anemoi/graphs/nodes/builder.py b/src/anemoi/graphs/nodes/builder.py index f184387..393d2fa 100644 --- a/src/anemoi/graphs/nodes/builder.py +++ b/src/anemoi/graphs/nodes/builder.py @@ -211,31 +211,28 @@ class RefinedIcosahedralNodes(BaseNodeBuilder, ABC): def __init__( self, resolution: Union[int, list[int]], - np_dtype: np.dtype = np.float32, + name: str, ) -> None: - # TODO: Discuss np_dtype - self.np_dtype = np_dtype - if isinstance(resolution, int): self.resolutions = list(range(resolution + 1)) else: self.resolutions = resolution - super().__init__() + super().__init__(name) def get_coordinates(self) -> torch.Tensor: self.nx_graph, coords_rad, self.node_ordering = self.create_nodes() - return torch.tensor(coords_rad[self.node_ordering]) + return torch.tensor(coords_rad[self.node_ordering], dtype=torch.float32) @abstractmethod def create_nodes(self) -> np.ndarray: ... - def register_attributes(self, graph: HeteroData, name: str, config: DotDict) -> HeteroData: - graph[name]["resolutions"] = self.resolutions - graph[name]["nx_graph"] = self.nx_graph - graph[name]["node_ordering"] = self.node_ordering + def register_attributes(self, graph: HeteroData, config: DotDict) -> HeteroData: + graph[self.name]["resolutions"] = self.resolutions + graph[self.name]["nx_graph"] = self.nx_graph + graph[self.name]["node_ordering"] = self.node_ordering # TODO: AOI mask builder is not used in the current implementation. - return super().register_attributes(graph, name, config) + return super().register_attributes(graph, config) class TriRefinedIcosahedralNodes(RefinedIcosahedralNodes): diff --git a/tests/edges/test_attributes.py b/tests/edges/test_attributes.py deleted file mode 100644 index dcd756d..0000000 --- a/tests/edges/test_attributes.py +++ /dev/null @@ -1,20 +0,0 @@ -import pytest -import torch - -from anemoi.graphs.edges.attributes import DirectionalFeatures - - -@pytest.mark.parametrize("norm", ["l1", "l2", "unit-max", "unit-sum", "unit-std"]) -@pytest.mark.parametrize("luse_rotated_features", [True, False]) -def test_directional_features(graph_nodes_and_edges, norm, luse_rotated_features: bool): - """Test DirectionalFeatures compute method.""" - edge_attr_builder = DirectionalFeatures(norm=norm, luse_rotated_features=luse_rotated_features) - edge_attr = edge_attr_builder(graph_nodes_and_edges, "test_nodes", "test_nodes") - assert isinstance(edge_attr, torch.Tensor) - - -def test_fail_directional_features(graph_nodes_and_edges): - """Test DirectionalFeatures compute method.""" - edge_attr_builder = DirectionalFeatures() - with pytest.raises(AttributeError): - edge_attr_builder(graph_nodes_and_edges, "test_nodes", "unknown_nodes") diff --git a/tests/edges/test_hexagonal_edges.py b/tests/edges/test_hexagonal_edges.py index d6bb29c..95cfce1 100644 --- a/tests/edges/test_hexagonal_edges.py +++ b/tests/edges/test_hexagonal_edges.py @@ -8,13 +8,18 @@ class TestTriIcosahedralEdgesInit: def test_init(self): """Test TriIcosahedralEdges initialization.""" - assert isinstance(HexagonalEdges("test_nodes"), HexagonalEdges) + assert isinstance(HexagonalEdges("test_nodes", "test_nodes"), HexagonalEdges) @pytest.mark.parametrize("depth_children", [-0.5, "hello", None, -4]) def test_fail_init(self, depth_children: str): """Test HexagonalEdges initialization with invalid cutoff.""" with pytest.raises(AssertionError): - HexagonalEdges("test_nodes", True, depth_children) + HexagonalEdges("test_nodes", "test_nodes", True, depth_children) + + def test_fail_init_diff_nodes(self): + """Test HexagonalEdges initialization with invalid nodes.""" + with pytest.raises(AssertionError): + HexagonalEdges("test_nodes", "test_nodes2", 0) class TestTriIcosahedralEdgesTransform: @@ -23,7 +28,7 @@ class TestTriIcosahedralEdgesTransform: def ico_graph(self) -> HeteroData: """Return a HeteroData object with HexRefinedIcosahedralNodes.""" graph = HeteroData() - graph = HexRefinedIcosahedralNodes(0).transform(graph, "test_nodes", {}) + graph = HexRefinedIcosahedralNodes(0, "test_nodes").update_graph(graph, {}) graph["fail_nodes"].x = [1, 2, 3] graph["fail_nodes"].node_type = "FailNodes" return graph @@ -31,12 +36,12 @@ def ico_graph(self) -> HeteroData: def test_transform_same_src_dst_nodes(self, ico_graph: HeteroData): """Test HexagonalEdges transform method.""" - tri_icosahedral_edges = HexagonalEdges("test_nodes") - graph = tri_icosahedral_edges.transform(ico_graph) + tri_icosahedral_edges = HexagonalEdges("test_nodes", "test_nodes") + graph = tri_icosahedral_edges.update_graph(ico_graph) assert ("test_nodes", "to", "test_nodes") in graph.edge_types def test_transform_fail_nodes(self, ico_graph: HeteroData): """Test HexagonalEdges transform method with wrong node type.""" - tri_icosahedral_edges = HexagonalEdges("fail_nodes") + tri_icosahedral_edges = HexagonalEdges("fail_nodes", "fail_nodes") with pytest.raises(AssertionError): - tri_icosahedral_edges.transform(ico_graph) + tri_icosahedral_edges.update_graph(ico_graph) diff --git a/tests/edges/test_tri_icosahedral_edges.py b/tests/edges/test_tri_icosahedral_edges.py index 9663cb0..53c0518 100644 --- a/tests/edges/test_tri_icosahedral_edges.py +++ b/tests/edges/test_tri_icosahedral_edges.py @@ -8,13 +8,18 @@ class TestTriIcosahedralEdgesInit: def test_init(self): """Test TriIcosahedralEdges initialization.""" - assert isinstance(TriIcosahedralEdges("test_nodes", 1), TriIcosahedralEdges) + assert isinstance(TriIcosahedralEdges("test_nodes", "test_nodes", 1), TriIcosahedralEdges) @pytest.mark.parametrize("xhops", [-0.5, "hello", None, -4]) def test_fail_init(self, xhops: str): - """Test TriIcosahedralEdges initialization with invalid cutoff.""" + """Test TriIcosahedralEdges initialization with invalid xhops.""" with pytest.raises(AssertionError): - TriIcosahedralEdges("test_nodes", xhops) + TriIcosahedralEdges("test_nodes", "test_nodes", xhops) + + def test_fail_init_diff_nodes(self): + """Test TriIcosahedralEdges initialization with invalid nodes.""" + with pytest.raises(AssertionError): + TriIcosahedralEdges("test_nodes", "test_nodes2", 0) class TestTriIcosahedralEdgesTransform: @@ -23,20 +28,20 @@ class TestTriIcosahedralEdgesTransform: def ico_graph(self) -> HeteroData: """Return a HeteroData object with TriRefinedIcosahedralNodes.""" graph = HeteroData() - graph = TriRefinedIcosahedralNodes(0).transform(graph, "test_nodes", {}) + graph = TriRefinedIcosahedralNodes(1, "test_nodes").update_graph(graph, {}) graph["fail_nodes"].x = [1, 2, 3] graph["fail_nodes"].node_type = "FailNodes" return graph def test_transform_same_src_dst_nodes(self, ico_graph: HeteroData): - """Test TriIcosahedralEdges transform method.""" + """Test TriIcosahedralEdges update method.""" - tri_icosahedral_edges = TriIcosahedralEdges("test_nodes", 1) - graph = tri_icosahedral_edges.transform(ico_graph) + tri_icosahedral_edges = TriIcosahedralEdges("test_nodes", "test_nodes", 1) + graph = tri_icosahedral_edges.update_graph(ico_graph) assert ("test_nodes", "to", "test_nodes") in graph.edge_types def test_transform_fail_nodes(self, ico_graph: HeteroData): - """Test TriIcosahedralEdges transform method with wrong node type.""" - tri_icosahedral_edges = TriIcosahedralEdges("fail_nodes", 1) + """Test TriIcosahedralEdges update method with wrong node type.""" + tri_icosahedral_edges = TriIcosahedralEdges("fail_nodes", "fail_nodes", 1) with pytest.raises(AssertionError): - tri_icosahedral_edges.transform(ico_graph) + tri_icosahedral_edges.update_graph(ico_graph) diff --git a/tests/nodes/test_hex_refined_icosahedral.py b/tests/nodes/test_hex_refined_icosahedral.py index df0e716..bd84f48 100644 --- a/tests/nodes/test_hex_refined_icosahedral.py +++ b/tests/nodes/test_hex_refined_icosahedral.py @@ -2,32 +2,33 @@ import torch from torch_geometric.data import HeteroData -from anemoi.graphs.nodes import builder +from anemoi.graphs.nodes.builder import BaseNodeBuilder +from anemoi.graphs.nodes.builder import HexRefinedIcosahedralNodes @pytest.mark.parametrize("resolution", [0, 2]) def test_init(resolution: list[int]): """Test TrirefinedIcosahedralNodes initialization.""" - node_builder = builder.HexRefinedIcosahedralNodes(resolution) - assert isinstance(node_builder, builder.BaseNodeBuilder) - assert isinstance(node_builder, builder.HexRefinedIcosahedralNodes) + node_builder = HexRefinedIcosahedralNodes(resolution, "test_nodes") + assert isinstance(node_builder, BaseNodeBuilder) + assert isinstance(node_builder, HexRefinedIcosahedralNodes) def test_get_coordinates(): """Test get_coordinates method.""" - node_builder = builder.HexRefinedIcosahedralNodes(0) + node_builder = HexRefinedIcosahedralNodes(0, "test_nodes") coords = node_builder.get_coordinates() assert isinstance(coords, torch.Tensor) assert coords.shape == (122, 2) -def test_transform(): - """Test transform method.""" - node_builder = builder.HexRefinedIcosahedralNodes(0) +def test_update_graph(): + """Test update_graph method.""" + node_builder = HexRefinedIcosahedralNodes(0, "test_nodes") graph = HeteroData() - graph = node_builder.transform(graph, "test", {}) - assert "resolutions" in graph["test"] - assert "nx_graph" in graph["test"] - assert "node_ordering" in graph["test"] - assert len(graph["test"]["node_ordering"]) == graph["test"].num_nodes + graph = node_builder.update_graph(graph, {}) + assert "resolutions" in graph["test_nodes"] + assert "nx_graph" in graph["test_nodes"] + assert "node_ordering" in graph["test_nodes"] + assert len(graph["test_nodes"]["node_ordering"]) == graph["test_nodes"].num_nodes diff --git a/tests/nodes/test_tri_refined_icosahedral.py b/tests/nodes/test_tri_refined_icosahedral.py index 762efdf..c3b008e 100644 --- a/tests/nodes/test_tri_refined_icosahedral.py +++ b/tests/nodes/test_tri_refined_icosahedral.py @@ -2,32 +2,33 @@ import torch from torch_geometric.data import HeteroData -from anemoi.graphs.nodes import builder +from anemoi.graphs.nodes.builder import BaseNodeBuilder +from anemoi.graphs.nodes.builder import TriRefinedIcosahedralNodes @pytest.mark.parametrize("resolution", [0, 2]) def test_init(resolution: list[int]): """Test TrirefinedIcosahedralNodes initialization.""" - node_builder = builder.TriRefinedIcosahedralNodes(resolution) - assert isinstance(node_builder, builder.BaseNodeBuilder) - assert isinstance(node_builder, builder.TriRefinedIcosahedralNodes) + node_builder = TriRefinedIcosahedralNodes(resolution, "test_nodes") + assert isinstance(node_builder, BaseNodeBuilder) + assert isinstance(node_builder, TriRefinedIcosahedralNodes) def test_get_coordinates(): """Test get_coordinates method.""" - node_builder = builder.TriRefinedIcosahedralNodes(2) + node_builder = TriRefinedIcosahedralNodes(2, "test_nodes") coords = node_builder.get_coordinates() assert isinstance(coords, torch.Tensor) assert coords.shape == (162, 2) -def test_transform(): - """Test transform method.""" - node_builder = builder.TriRefinedIcosahedralNodes(1) +def test_update_graph(): + """Test update_graph method.""" + node_builder = TriRefinedIcosahedralNodes(1, "test_nodes") graph = HeteroData() - graph = node_builder.transform(graph, "test", {}) - assert "resolutions" in graph["test"] - assert "nx_graph" in graph["test"] - assert "node_ordering" in graph["test"] - assert len(graph["test"]["node_ordering"]) == graph["test"].num_nodes + graph = node_builder.update_graph(graph, {}) + assert "resolutions" in graph["test_nodes"] + assert "nx_graph" in graph["test_nodes"] + assert "node_ordering" in graph["test_nodes"] + assert len(graph["test_nodes"]["node_ordering"]) == graph["test_nodes"].num_nodes From fe8a8e53ee429fd05ceca44a0eefc089b2b5f261 Mon Sep 17 00:00:00 2001 From: Mario Santa Cruz Date: Tue, 9 Jul 2024 10:58:20 +0000 Subject: [PATCH 103/156] Include xhops to hexagonal edges --- src/anemoi/graphs/edges/builder.py | 82 ++++++++++--------------- src/anemoi/graphs/generate/hexagonal.py | 10 +-- tests/edges/test_hexagonal_edges.py | 14 ++--- 3 files changed, 47 insertions(+), 59 deletions(-) diff --git a/src/anemoi/graphs/edges/builder.py b/src/anemoi/graphs/edges/builder.py index d27c95d..43033f9 100644 --- a/src/anemoi/graphs/edges/builder.py +++ b/src/anemoi/graphs/edges/builder.py @@ -15,6 +15,7 @@ from anemoi.graphs import EARTH_RADIUS from anemoi.graphs.generate import hexagonal from anemoi.graphs.generate import icosahedral +from anemoi.graphs.nodes.builder import BaseNodeBuilder from anemoi.graphs.nodes.builder import HexRefinedIcosahedralNodes from anemoi.graphs.nodes.builder import TriRefinedIcosahedralNodes from anemoi.graphs.utils import get_grid_reference_distance @@ -267,31 +268,42 @@ def get_adjacency_matrix(self, source_nodes: NodeStorage, target_nodes: NodeStor return adj_matrix -class TriIcosahedralEdges(BaseEdgeBuilder): - """Computes icosahedral edges and adds them to a HeteroData graph.""" +class MultiScaleEdges(BaseEdgeBuilder, ABC): + """Base class for multi-scale edges in the nodes of a graph.""" def __init__(self, source_name: str, target_name: str, xhops: int): super().__init__(source_name, target_name) - - assert source_name == target_name, "TriIcosahedralEdges requires source and target nodes to be the same." + assert source_name == target_name, f"{self.__class__.__name__} requires source and target nodes to be the same." assert isinstance(xhops, int), "Number of xhops must be an integer" assert xhops > 0, "Number of xhops must be positive" - self.xhops = xhops - def update_graph(self, graph: HeteroData, attrs_config: Optional[DotDict] = None) -> HeteroData: + @abstractmethod + def base_node_class(self) -> BaseNodeBuilder: ... - assert ( - graph[self.source_name].node_type == TriRefinedIcosahedralNodes.__name__ - ), f"{self.__class__.__name__} requires TriRefinedIcosahedralNodes." + def post_process_adjmat(self, nodes: NodeStorage, adjmat): + graph_sorted = dict(zip(nodes["node_ordering"], range(len(nodes.node_ordering)))) + sort_func = np.vectorize(graph_sorted.get) + adjmat.row = sort_func(adjmat.row) + adjmat.col = sort_func(adjmat.col) + return adjmat - # TODO: Next assert doesn't exist anymore since filters were moved, make sure this is checked where appropriate - # assert filter_src is None and filter_dst is None, "InheritConnection does not support filtering with attributes." + def update_graph(self, graph: HeteroData, attrs_config: DotDict | None = None) -> HeteroData: + assert ( + graph[self.source_name].node_type == self.base_node_class.__name__ + ), f"{self.__class__.__name__} requires {self.base_node_class.__name__}." return super().update_graph(graph, attrs_config) - def get_adjacency_matrix(self, source_nodes: NodeStorage, target_nodes: NodeStorage): +class TriIcosahedralEdges(MultiScaleEdges): + """Computes icosahedral edges and adds them to a HeteroData graph.""" + + @property + def base_node_class(self) -> BaseNodeBuilder: + return TriRefinedIcosahedralNodes + + def get_adjacency_matrix(self, source_nodes: NodeStorage, target_nodes: NodeStorage): source_nodes["nx_graph"] = icosahedral.add_edges_to_nx_graph( source_nodes["nx_graph"], resolutions=source_nodes["resolutions"], @@ -302,54 +314,28 @@ def get_adjacency_matrix(self, source_nodes: NodeStorage, target_nodes: NodeStor source_nodes["nx_graph"], nodelist=list(source_nodes["nx_graph"]), format="coo" ) graph_1_sorted = dict(zip(range(len(source_nodes["nx_graph"].nodes)), list(source_nodes["nx_graph"].nodes))) - graph_2_sorted = dict(zip(source_nodes.node_ordering, range(len(source_nodes.node_ordering)))) sort_func1 = np.vectorize(graph_1_sorted.get) - sort_func2 = np.vectorize(graph_2_sorted.get) - adjmat.row = sort_func2(sort_func1(adjmat.row)) - adjmat.col = sort_func2(sort_func1(adjmat.col)) + adjmat.row = sort_func1(adjmat.row) + adjmat.col = sort_func1(adjmat.col) + + self.post_process_adjmat(source_nodes, adjmat) return adjmat -class HexagonalEdges(BaseEdgeBuilder): +class HexagonalEdges(MultiScaleEdges): """Computes hexagonal edges and adds them to a HeteroData graph.""" - def __init__( - self, - source_name: str, - target_name: str, - add_neighbouring_children: bool = False, - depth_children: Optional[int] = 1, - ): - super().__init__(source_name, source_name) - self.add_neighbouring_children = add_neighbouring_children - - assert source_name == target_name, "TriIcosahedralEdges requires source and target nodes to be the same." - assert isinstance(depth_children, int), "Depth of children must be an integer" - assert depth_children > 0, "Depth of children must be positive" - self.depth_children = depth_children - - def update_graph(self, graph: HeteroData, attrs_config: Optional[DotDict] = None) -> HeteroData: - assert ( - graph[self.source_name].node_type == HexRefinedIcosahedralNodes.__name__ - ), f"{self.__class__.__name__} requires HexRefinedIcosahedralNodes." - - # TODO: Next assert doesn't exist anymore since filters were moved, make sure this is checked where appropriate - # assert filter_src is None and filter_dst is None, "InheritConnection does not support filtering with attributes." - - return super().update_graph(graph, attrs_config) + @property + def base_node_class(self) -> BaseNodeBuilder: + return HexRefinedIcosahedralNodes def get_adjacency_matrix(self, source_nodes: NodeStorage, target_nodes: NodeStorage): - source_nodes["nx_graph"] = hexagonal.add_edges_to_nx_graph( source_nodes["nx_graph"], resolutions=source_nodes["resolutions"], - neighbour_children=self.add_neighbouring_children, - depth_children=self.depth_children, + xhops=self.xhops, ) adjmat = nx.to_scipy_sparse_array(source_nodes["nx_graph"], format="coo") - graph_2_sorted = dict(zip(source_nodes["node_ordering"], range(len(source_nodes.node_ordering)))) - sort_func = np.vectorize(graph_2_sorted.get) - adjmat.row = sort_func(adjmat.row) - adjmat.col = sort_func(adjmat.col) + self.post_process_adjmat(source_nodes, adjmat) return adjmat diff --git a/src/anemoi/graphs/generate/hexagonal.py b/src/anemoi/graphs/generate/hexagonal.py index 4ab9569..dc534e0 100644 --- a/src/anemoi/graphs/generate/hexagonal.py +++ b/src/anemoi/graphs/generate/hexagonal.py @@ -112,6 +112,7 @@ def get_cells_at_resolution( def add_edges_to_nx_graph( graph: nx.Graph, resolutions: list[int], + xhops: int = 1, self_loop: bool = False, flat: bool = True, neighbour_children: bool = False, @@ -128,6 +129,8 @@ def add_edges_to_nx_graph( The graph to add the nodes. resolutions : list[int] Levels of mesh resolution to consider. + xhops: int + The number of hops to consider for the neighbours. self_loop : bool Whether include a self-loop in every node or not. flat : bool @@ -146,7 +149,7 @@ def add_edges_to_nx_graph( if self_loop: add_self_loops(graph) - add_neighbour_edges(graph, resolutions, flat) + add_neighbour_edges(graph, resolutions, xhops, flat) add_children_edges( graph, resolutions, @@ -166,15 +169,14 @@ def add_self_loops(graph: nx.Graph) -> None: def add_neighbour_edges( graph: nx.Graph, refinement_levels: tuple[int], + xhops: int = 1, flat: bool = True, ) -> None: for resolution in refinement_levels: cells = {node for node in graph.nodes if h3.h3_get_resolution(node) == resolution} for idx in cells: - k = 2 if resolution == 0 else 1 # refinement_levels[0]: # extra large field of vision ; only few nodes - # neighbours - for idx_neighbour in h3.k_ring(idx, k=k) & cells: + for idx_neighbour in h3.k_ring(idx, k=xhops) & cells: if flat: add_edge( graph, diff --git a/tests/edges/test_hexagonal_edges.py b/tests/edges/test_hexagonal_edges.py index 95cfce1..08a580d 100644 --- a/tests/edges/test_hexagonal_edges.py +++ b/tests/edges/test_hexagonal_edges.py @@ -8,13 +8,13 @@ class TestTriIcosahedralEdgesInit: def test_init(self): """Test TriIcosahedralEdges initialization.""" - assert isinstance(HexagonalEdges("test_nodes", "test_nodes"), HexagonalEdges) + assert isinstance(HexagonalEdges("test_nodes", "test_nodes", 1), HexagonalEdges) - @pytest.mark.parametrize("depth_children", [-0.5, "hello", None, -4]) - def test_fail_init(self, depth_children: str): - """Test HexagonalEdges initialization with invalid cutoff.""" + @pytest.mark.parametrize("xhops", [-0.5, "hello", None, -4]) + def test_fail_init(self, xhops: int): + """Test HexagonalEdges initialization with invalid xhops.""" with pytest.raises(AssertionError): - HexagonalEdges("test_nodes", "test_nodes", True, depth_children) + HexagonalEdges("test_nodes", "test_nodes", xhops) def test_fail_init_diff_nodes(self): """Test HexagonalEdges initialization with invalid nodes.""" @@ -36,12 +36,12 @@ def ico_graph(self) -> HeteroData: def test_transform_same_src_dst_nodes(self, ico_graph: HeteroData): """Test HexagonalEdges transform method.""" - tri_icosahedral_edges = HexagonalEdges("test_nodes", "test_nodes") + tri_icosahedral_edges = HexagonalEdges("test_nodes", "test_nodes", 1) graph = tri_icosahedral_edges.update_graph(ico_graph) assert ("test_nodes", "to", "test_nodes") in graph.edge_types def test_transform_fail_nodes(self, ico_graph: HeteroData): """Test HexagonalEdges transform method with wrong node type.""" - tri_icosahedral_edges = HexagonalEdges("fail_nodes", "fail_nodes") + tri_icosahedral_edges = HexagonalEdges("fail_nodes", "fail_nodes", 1) with pytest.raises(AssertionError): tri_icosahedral_edges.update_graph(ico_graph) From 463911c68842bc2b0aeb60ac515f4d89fa46bdf5 Mon Sep 17 00:00:00 2001 From: theissenhelen Date: Tue, 9 Jul 2024 15:43:44 +0000 Subject: [PATCH 104/156] docs: update docstrings --- src/anemoi/graphs/edges/builder.py | 6 ++---- src/anemoi/graphs/nodes/builder.py | 16 ++++++++-------- 2 files changed, 10 insertions(+), 12 deletions(-) diff --git a/src/anemoi/graphs/edges/builder.py b/src/anemoi/graphs/edges/builder.py index 43033f9..f4335fb 100644 --- a/src/anemoi/graphs/edges/builder.py +++ b/src/anemoi/graphs/edges/builder.py @@ -39,7 +39,7 @@ def name(self) -> tuple[str, str, str]: def get_adjacency_matrix(self, source_nodes: NodeStorage, target_nodes: NodeStorage): ... def prepare_node_data(self, graph: HeteroData) -> tuple[NodeStorage, NodeStorage]: - """Prepare nodes information.""" + """Prepare node information and get source and target nodes.""" return graph[self.source_name], graph[self.target_name] def get_edge_index(self, graph: HeteroData) -> torch.Tensor: @@ -194,8 +194,6 @@ class CutOffEdges(BaseEdgeBuilder): The name of the target nodes. cutoff_factor : float Factor to multiply the grid reference distance to get the cut-off radius. - radius : float - Cut-off radius. Methods ------- @@ -241,7 +239,7 @@ def get_cutoff_radius(self, graph: HeteroData, mask_attr: Optional[torch.Tensor] return radius def prepare_node_data(self, graph: HeteroData) -> tuple[NodeStorage, NodeStorage]: - """Prepare nodes information.""" + """Prepare node information and get source and target nodes.""" self.radius = self.get_cutoff_radius(graph) return super().prepare_node_data(graph) diff --git a/src/anemoi/graphs/nodes/builder.py b/src/anemoi/graphs/nodes/builder.py index c161fd7..4197361 100644 --- a/src/anemoi/graphs/nodes/builder.py +++ b/src/anemoi/graphs/nodes/builder.py @@ -22,6 +22,11 @@ class BaseNodeBuilder(ABC): """Base class for node builders. The node coordinates are stored in the `x` attribute of the nodes and they are stored in radians. + + Attributes + ---------- + name : str + name of the nodes, key for the nodes in the HeteroData graph object. """ def __init__(self, name: str) -> None: @@ -110,7 +115,7 @@ class ZarrDatasetNodes(BaseNodeBuilder): Attributes ---------- - ds : zarr.core.Array + dataset : zarr.core.Array The dataset. Methods @@ -127,7 +132,7 @@ class ZarrDatasetNodes(BaseNodeBuilder): def __init__(self, dataset: DotDict, name: str) -> None: LOGGER.info("Reading the dataset from %s.", dataset) - self.ds = open_dataset(dataset) + self.dataset = open_dataset(dataset) super().__init__(name) def get_coordinates(self) -> torch.Tensor: @@ -138,7 +143,7 @@ def get_coordinates(self) -> torch.Tensor: torch.Tensor of shape (N, 2) Coordinates of the nodes. """ - return self.reshape_coords(self.ds.latitudes, self.ds.longitudes) + return self.reshape_coords(self.dataset.latitudes, self.dataset.longitudes) class NPZFileNodes(BaseNodeBuilder): @@ -203,8 +208,6 @@ class RefinedIcosahedralNodes(BaseNodeBuilder, ABC): ---------- resolution : list[int] | int Refinement level of the mesh. - np_dtype : np.dtype, optional - The numpy data type to use, by default np.float32. """ def __init__( @@ -230,7 +233,6 @@ def register_attributes(self, graph: HeteroData, config: DotDict) -> HeteroData: graph[self.name]["resolutions"] = self.resolutions graph[self.name]["nx_graph"] = self.nx_graph graph[self.name]["node_ordering"] = self.node_ordering - # TODO: AOI mask builder is not used in the current implementation. return super().register_attributes(graph, config) @@ -238,7 +240,6 @@ class TriRefinedIcosahedralNodes(RefinedIcosahedralNodes): """It depends on the trimesh Python library.""" def create_nodes(self) -> np.ndarray: - # TODO: AOI mask builder is not used in the current implementation. return create_icosahedral_nodes(resolutions=self.resolutions) @@ -246,5 +247,4 @@ class HexRefinedIcosahedralNodes(RefinedIcosahedralNodes): """It depends on the h3 Python library.""" def create_nodes(self) -> np.ndarray: - # TODO: AOI mask builder is not used in the current implementation. return create_hexagonal_nodes(self.resolutions) From b0a35b8d05011047e764f3ca2c89fb33d9225365 Mon Sep 17 00:00:00 2001 From: Mario Santa Cruz Date: Wed, 10 Jul 2024 14:38:05 +0000 Subject: [PATCH 105/156] fix: update attribute name --- tests/nodes/test_zarr.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/nodes/test_zarr.py b/tests/nodes/test_zarr.py index e3a2687..e7c98cc 100644 --- a/tests/nodes/test_zarr.py +++ b/tests/nodes/test_zarr.py @@ -33,7 +33,7 @@ def test_register_nodes(mocker, mock_zarr_dataset): assert graph["test_nodes"].x is not None assert isinstance(graph["test_nodes"].x, torch.Tensor) - assert graph["test_nodes"].x.shape == (node_builder.ds.num_nodes, 2) + assert graph["test_nodes"].x.shape == (node_builder.dataset.num_nodes, 2) assert graph["test_nodes"].node_type == "ZarrDatasetNodes" From 4f445a927f74829ab487fd196c3972c314c01663 Mon Sep 17 00:00:00 2001 From: theissenhelen Date: Thu, 11 Jul 2024 09:08:14 +0000 Subject: [PATCH 106/156] refactor: rename multiscale nodes --- src/anemoi/graphs/edges/builder.py | 8 ++++---- src/anemoi/graphs/nodes/__init__.py | 6 +++--- src/anemoi/graphs/nodes/builder.py | 6 +++--- tests/edges/test_hexagonal_edges.py | 4 ++-- tests/edges/test_tri_icosahedral_edges.py | 4 ++-- ..._refined_icosahedral.py => test_hexagonal_nodes.py} | 10 +++++----- ...ed_icosahedral.py => test_tri_icosahedral_nodes.py} | 10 +++++----- 7 files changed, 24 insertions(+), 24 deletions(-) rename tests/nodes/{test_hex_refined_icosahedral.py => test_hexagonal_nodes.py} (72%) rename tests/nodes/{test_tri_refined_icosahedral.py => test_tri_icosahedral_nodes.py} (72%) diff --git a/src/anemoi/graphs/edges/builder.py b/src/anemoi/graphs/edges/builder.py index f4335fb..eb72f5b 100644 --- a/src/anemoi/graphs/edges/builder.py +++ b/src/anemoi/graphs/edges/builder.py @@ -16,8 +16,8 @@ from anemoi.graphs.generate import hexagonal from anemoi.graphs.generate import icosahedral from anemoi.graphs.nodes.builder import BaseNodeBuilder -from anemoi.graphs.nodes.builder import HexRefinedIcosahedralNodes -from anemoi.graphs.nodes.builder import TriRefinedIcosahedralNodes +from anemoi.graphs.nodes.builder import HexagonalNodes +from anemoi.graphs.nodes.builder import TriIcosahedralNodes from anemoi.graphs.utils import get_grid_reference_distance LOGGER = logging.getLogger(__name__) @@ -299,7 +299,7 @@ class TriIcosahedralEdges(MultiScaleEdges): @property def base_node_class(self) -> BaseNodeBuilder: - return TriRefinedIcosahedralNodes + return TriIcosahedralNodes def get_adjacency_matrix(self, source_nodes: NodeStorage, target_nodes: NodeStorage): source_nodes["nx_graph"] = icosahedral.add_edges_to_nx_graph( @@ -325,7 +325,7 @@ class HexagonalEdges(MultiScaleEdges): @property def base_node_class(self) -> BaseNodeBuilder: - return HexRefinedIcosahedralNodes + return HexagonalNodes def get_adjacency_matrix(self, source_nodes: NodeStorage, target_nodes: NodeStorage): source_nodes["nx_graph"] = hexagonal.add_edges_to_nx_graph( diff --git a/src/anemoi/graphs/nodes/__init__.py b/src/anemoi/graphs/nodes/__init__.py index ef13e41..02eba93 100644 --- a/src/anemoi/graphs/nodes/__init__.py +++ b/src/anemoi/graphs/nodes/__init__.py @@ -1,6 +1,6 @@ -from .builder import HexRefinedIcosahedralNodes +from .builder import HexagonalNodes from .builder import NPZFileNodes -from .builder import TriRefinedIcosahedralNodes +from .builder import TriIcosahedralNodes from .builder import ZarrDatasetNodes -__all__ = ["ZarrDatasetNodes", "NPZFileNodes", "TriRefinedIcosahedralNodes", "HexRefinedIcosahedralNodes"] +__all__ = ["ZarrDatasetNodes", "NPZFileNodes", "TriIcosahedralNodes", "HexagonalNodes"] diff --git a/src/anemoi/graphs/nodes/builder.py b/src/anemoi/graphs/nodes/builder.py index 4197361..7cdd040 100644 --- a/src/anemoi/graphs/nodes/builder.py +++ b/src/anemoi/graphs/nodes/builder.py @@ -199,7 +199,7 @@ def get_coordinates(self) -> torch.Tensor: return coords -class RefinedIcosahedralNodes(BaseNodeBuilder, ABC): +class MultiScaleNodes(BaseNodeBuilder, ABC): """Processor mesh based on a triangular mesh. It is based on the icosahedral mesh, which is a mesh of triangles that covers the sphere. @@ -236,14 +236,14 @@ def register_attributes(self, graph: HeteroData, config: DotDict) -> HeteroData: return super().register_attributes(graph, config) -class TriRefinedIcosahedralNodes(RefinedIcosahedralNodes): +class TriIcosahedralNodes(MultiScaleNodes): """It depends on the trimesh Python library.""" def create_nodes(self) -> np.ndarray: return create_icosahedral_nodes(resolutions=self.resolutions) -class HexRefinedIcosahedralNodes(RefinedIcosahedralNodes): +class HexagonalNodes(MultiScaleNodes): """It depends on the h3 Python library.""" def create_nodes(self) -> np.ndarray: diff --git a/tests/edges/test_hexagonal_edges.py b/tests/edges/test_hexagonal_edges.py index 08a580d..16774c1 100644 --- a/tests/edges/test_hexagonal_edges.py +++ b/tests/edges/test_hexagonal_edges.py @@ -2,7 +2,7 @@ from torch_geometric.data import HeteroData from anemoi.graphs.edges import HexagonalEdges -from anemoi.graphs.nodes import HexRefinedIcosahedralNodes +from anemoi.graphs.nodes import HexagonalNodes class TestTriIcosahedralEdgesInit: @@ -28,7 +28,7 @@ class TestTriIcosahedralEdgesTransform: def ico_graph(self) -> HeteroData: """Return a HeteroData object with HexRefinedIcosahedralNodes.""" graph = HeteroData() - graph = HexRefinedIcosahedralNodes(0, "test_nodes").update_graph(graph, {}) + graph = HexagonalNodes(0, "test_nodes").update_graph(graph, {}) graph["fail_nodes"].x = [1, 2, 3] graph["fail_nodes"].node_type = "FailNodes" return graph diff --git a/tests/edges/test_tri_icosahedral_edges.py b/tests/edges/test_tri_icosahedral_edges.py index 53c0518..24c95e4 100644 --- a/tests/edges/test_tri_icosahedral_edges.py +++ b/tests/edges/test_tri_icosahedral_edges.py @@ -2,7 +2,7 @@ from torch_geometric.data import HeteroData from anemoi.graphs.edges import TriIcosahedralEdges -from anemoi.graphs.nodes import TriRefinedIcosahedralNodes +from anemoi.graphs.nodes import TriIcosahedralNodes class TestTriIcosahedralEdgesInit: @@ -28,7 +28,7 @@ class TestTriIcosahedralEdgesTransform: def ico_graph(self) -> HeteroData: """Return a HeteroData object with TriRefinedIcosahedralNodes.""" graph = HeteroData() - graph = TriRefinedIcosahedralNodes(1, "test_nodes").update_graph(graph, {}) + graph = TriIcosahedralNodes(1, "test_nodes").update_graph(graph, {}) graph["fail_nodes"].x = [1, 2, 3] graph["fail_nodes"].node_type = "FailNodes" return graph diff --git a/tests/nodes/test_hex_refined_icosahedral.py b/tests/nodes/test_hexagonal_nodes.py similarity index 72% rename from tests/nodes/test_hex_refined_icosahedral.py rename to tests/nodes/test_hexagonal_nodes.py index bd84f48..8722d98 100644 --- a/tests/nodes/test_hex_refined_icosahedral.py +++ b/tests/nodes/test_hexagonal_nodes.py @@ -2,22 +2,22 @@ import torch from torch_geometric.data import HeteroData +from anemoi.graphs.nodes import HexagonalNodes from anemoi.graphs.nodes.builder import BaseNodeBuilder -from anemoi.graphs.nodes.builder import HexRefinedIcosahedralNodes @pytest.mark.parametrize("resolution", [0, 2]) def test_init(resolution: list[int]): """Test TrirefinedIcosahedralNodes initialization.""" - node_builder = HexRefinedIcosahedralNodes(resolution, "test_nodes") + node_builder = HexagonalNodes(resolution, "test_nodes") assert isinstance(node_builder, BaseNodeBuilder) - assert isinstance(node_builder, HexRefinedIcosahedralNodes) + assert isinstance(node_builder, HexagonalNodes) def test_get_coordinates(): """Test get_coordinates method.""" - node_builder = HexRefinedIcosahedralNodes(0, "test_nodes") + node_builder = HexagonalNodes(0, "test_nodes") coords = node_builder.get_coordinates() assert isinstance(coords, torch.Tensor) assert coords.shape == (122, 2) @@ -25,7 +25,7 @@ def test_get_coordinates(): def test_update_graph(): """Test update_graph method.""" - node_builder = HexRefinedIcosahedralNodes(0, "test_nodes") + node_builder = HexagonalNodes(0, "test_nodes") graph = HeteroData() graph = node_builder.update_graph(graph, {}) assert "resolutions" in graph["test_nodes"] diff --git a/tests/nodes/test_tri_refined_icosahedral.py b/tests/nodes/test_tri_icosahedral_nodes.py similarity index 72% rename from tests/nodes/test_tri_refined_icosahedral.py rename to tests/nodes/test_tri_icosahedral_nodes.py index c3b008e..2d9f8ca 100644 --- a/tests/nodes/test_tri_refined_icosahedral.py +++ b/tests/nodes/test_tri_icosahedral_nodes.py @@ -2,22 +2,22 @@ import torch from torch_geometric.data import HeteroData +from anemoi.graphs.nodes import TriIcosahedralNodes from anemoi.graphs.nodes.builder import BaseNodeBuilder -from anemoi.graphs.nodes.builder import TriRefinedIcosahedralNodes @pytest.mark.parametrize("resolution", [0, 2]) def test_init(resolution: list[int]): """Test TrirefinedIcosahedralNodes initialization.""" - node_builder = TriRefinedIcosahedralNodes(resolution, "test_nodes") + node_builder = TriIcosahedralNodes(resolution, "test_nodes") assert isinstance(node_builder, BaseNodeBuilder) - assert isinstance(node_builder, TriRefinedIcosahedralNodes) + assert isinstance(node_builder, TriIcosahedralNodes) def test_get_coordinates(): """Test get_coordinates method.""" - node_builder = TriRefinedIcosahedralNodes(2, "test_nodes") + node_builder = TriIcosahedralNodes(2, "test_nodes") coords = node_builder.get_coordinates() assert isinstance(coords, torch.Tensor) assert coords.shape == (162, 2) @@ -25,7 +25,7 @@ def test_get_coordinates(): def test_update_graph(): """Test update_graph method.""" - node_builder = TriRefinedIcosahedralNodes(1, "test_nodes") + node_builder = TriIcosahedralNodes(1, "test_nodes") graph = HeteroData() graph = node_builder.update_graph(graph, {}) assert "resolutions" in graph["test_nodes"] From fa812eb77d9a8adca2e98d715c28ccfd3ca9dbeb Mon Sep 17 00:00:00 2001 From: theissenhelen Date: Thu, 11 Jul 2024 10:38:10 +0000 Subject: [PATCH 107/156] refactor: rename icosahedral nodes --- src/anemoi/graphs/edges/builder.py | 8 ++++---- src/anemoi/graphs/nodes/__init__.py | 6 +++--- src/anemoi/graphs/nodes/builder.py | 6 +++--- tests/edges/test_hexagonal_edges.py | 4 ++-- tests/edges/test_tri_icosahedral_edges.py | 4 ++-- .../{test_hexagonal_nodes.py => test_hex_nodes.py} | 10 +++++----- ...test_tri_icosahedral_nodes.py => test_tri_nodes.py} | 10 +++++----- 7 files changed, 24 insertions(+), 24 deletions(-) rename tests/nodes/{test_hexagonal_nodes.py => test_hex_nodes.py} (77%) rename tests/nodes/{test_tri_icosahedral_nodes.py => test_tri_nodes.py} (75%) diff --git a/src/anemoi/graphs/edges/builder.py b/src/anemoi/graphs/edges/builder.py index eb72f5b..053f10f 100644 --- a/src/anemoi/graphs/edges/builder.py +++ b/src/anemoi/graphs/edges/builder.py @@ -16,8 +16,8 @@ from anemoi.graphs.generate import hexagonal from anemoi.graphs.generate import icosahedral from anemoi.graphs.nodes.builder import BaseNodeBuilder -from anemoi.graphs.nodes.builder import HexagonalNodes -from anemoi.graphs.nodes.builder import TriIcosahedralNodes +from anemoi.graphs.nodes.builder import HexNodes +from anemoi.graphs.nodes.builder import TriNodes from anemoi.graphs.utils import get_grid_reference_distance LOGGER = logging.getLogger(__name__) @@ -299,7 +299,7 @@ class TriIcosahedralEdges(MultiScaleEdges): @property def base_node_class(self) -> BaseNodeBuilder: - return TriIcosahedralNodes + return TriNodes def get_adjacency_matrix(self, source_nodes: NodeStorage, target_nodes: NodeStorage): source_nodes["nx_graph"] = icosahedral.add_edges_to_nx_graph( @@ -325,7 +325,7 @@ class HexagonalEdges(MultiScaleEdges): @property def base_node_class(self) -> BaseNodeBuilder: - return HexagonalNodes + return HexNodes def get_adjacency_matrix(self, source_nodes: NodeStorage, target_nodes: NodeStorage): source_nodes["nx_graph"] = hexagonal.add_edges_to_nx_graph( diff --git a/src/anemoi/graphs/nodes/__init__.py b/src/anemoi/graphs/nodes/__init__.py index 02eba93..7b6b149 100644 --- a/src/anemoi/graphs/nodes/__init__.py +++ b/src/anemoi/graphs/nodes/__init__.py @@ -1,6 +1,6 @@ -from .builder import HexagonalNodes +from .builder import HexNodes from .builder import NPZFileNodes -from .builder import TriIcosahedralNodes +from .builder import TriNodes from .builder import ZarrDatasetNodes -__all__ = ["ZarrDatasetNodes", "NPZFileNodes", "TriIcosahedralNodes", "HexagonalNodes"] +__all__ = ["ZarrDatasetNodes", "NPZFileNodes", "TriNodes", "HexNodes"] diff --git a/src/anemoi/graphs/nodes/builder.py b/src/anemoi/graphs/nodes/builder.py index 7cdd040..2e1c965 100644 --- a/src/anemoi/graphs/nodes/builder.py +++ b/src/anemoi/graphs/nodes/builder.py @@ -199,7 +199,7 @@ def get_coordinates(self) -> torch.Tensor: return coords -class MultiScaleNodes(BaseNodeBuilder, ABC): +class IcosahedralNodes(BaseNodeBuilder, ABC): """Processor mesh based on a triangular mesh. It is based on the icosahedral mesh, which is a mesh of triangles that covers the sphere. @@ -236,14 +236,14 @@ def register_attributes(self, graph: HeteroData, config: DotDict) -> HeteroData: return super().register_attributes(graph, config) -class TriIcosahedralNodes(MultiScaleNodes): +class TriNodes(IcosahedralNodes): """It depends on the trimesh Python library.""" def create_nodes(self) -> np.ndarray: return create_icosahedral_nodes(resolutions=self.resolutions) -class HexagonalNodes(MultiScaleNodes): +class HexNodes(IcosahedralNodes): """It depends on the h3 Python library.""" def create_nodes(self) -> np.ndarray: diff --git a/tests/edges/test_hexagonal_edges.py b/tests/edges/test_hexagonal_edges.py index 16774c1..63b79a5 100644 --- a/tests/edges/test_hexagonal_edges.py +++ b/tests/edges/test_hexagonal_edges.py @@ -2,7 +2,7 @@ from torch_geometric.data import HeteroData from anemoi.graphs.edges import HexagonalEdges -from anemoi.graphs.nodes import HexagonalNodes +from anemoi.graphs.nodes import HexNodes class TestTriIcosahedralEdgesInit: @@ -28,7 +28,7 @@ class TestTriIcosahedralEdgesTransform: def ico_graph(self) -> HeteroData: """Return a HeteroData object with HexRefinedIcosahedralNodes.""" graph = HeteroData() - graph = HexagonalNodes(0, "test_nodes").update_graph(graph, {}) + graph = HexNodes(0, "test_nodes").update_graph(graph, {}) graph["fail_nodes"].x = [1, 2, 3] graph["fail_nodes"].node_type = "FailNodes" return graph diff --git a/tests/edges/test_tri_icosahedral_edges.py b/tests/edges/test_tri_icosahedral_edges.py index 24c95e4..80cae2c 100644 --- a/tests/edges/test_tri_icosahedral_edges.py +++ b/tests/edges/test_tri_icosahedral_edges.py @@ -2,7 +2,7 @@ from torch_geometric.data import HeteroData from anemoi.graphs.edges import TriIcosahedralEdges -from anemoi.graphs.nodes import TriIcosahedralNodes +from anemoi.graphs.nodes import TriNodes class TestTriIcosahedralEdgesInit: @@ -28,7 +28,7 @@ class TestTriIcosahedralEdgesTransform: def ico_graph(self) -> HeteroData: """Return a HeteroData object with TriRefinedIcosahedralNodes.""" graph = HeteroData() - graph = TriIcosahedralNodes(1, "test_nodes").update_graph(graph, {}) + graph = TriNodes(1, "test_nodes").update_graph(graph, {}) graph["fail_nodes"].x = [1, 2, 3] graph["fail_nodes"].node_type = "FailNodes" return graph diff --git a/tests/nodes/test_hexagonal_nodes.py b/tests/nodes/test_hex_nodes.py similarity index 77% rename from tests/nodes/test_hexagonal_nodes.py rename to tests/nodes/test_hex_nodes.py index 8722d98..0753fa2 100644 --- a/tests/nodes/test_hexagonal_nodes.py +++ b/tests/nodes/test_hex_nodes.py @@ -2,7 +2,7 @@ import torch from torch_geometric.data import HeteroData -from anemoi.graphs.nodes import HexagonalNodes +from anemoi.graphs.nodes import HexNodes from anemoi.graphs.nodes.builder import BaseNodeBuilder @@ -10,14 +10,14 @@ def test_init(resolution: list[int]): """Test TrirefinedIcosahedralNodes initialization.""" - node_builder = HexagonalNodes(resolution, "test_nodes") + node_builder = HexNodes(resolution, "test_nodes") assert isinstance(node_builder, BaseNodeBuilder) - assert isinstance(node_builder, HexagonalNodes) + assert isinstance(node_builder, HexNodes) def test_get_coordinates(): """Test get_coordinates method.""" - node_builder = HexagonalNodes(0, "test_nodes") + node_builder = HexNodes(0, "test_nodes") coords = node_builder.get_coordinates() assert isinstance(coords, torch.Tensor) assert coords.shape == (122, 2) @@ -25,7 +25,7 @@ def test_get_coordinates(): def test_update_graph(): """Test update_graph method.""" - node_builder = HexagonalNodes(0, "test_nodes") + node_builder = HexNodes(0, "test_nodes") graph = HeteroData() graph = node_builder.update_graph(graph, {}) assert "resolutions" in graph["test_nodes"] diff --git a/tests/nodes/test_tri_icosahedral_nodes.py b/tests/nodes/test_tri_nodes.py similarity index 75% rename from tests/nodes/test_tri_icosahedral_nodes.py rename to tests/nodes/test_tri_nodes.py index 2d9f8ca..3c84aad 100644 --- a/tests/nodes/test_tri_icosahedral_nodes.py +++ b/tests/nodes/test_tri_nodes.py @@ -2,7 +2,7 @@ import torch from torch_geometric.data import HeteroData -from anemoi.graphs.nodes import TriIcosahedralNodes +from anemoi.graphs.nodes import TriNodes from anemoi.graphs.nodes.builder import BaseNodeBuilder @@ -10,14 +10,14 @@ def test_init(resolution: list[int]): """Test TrirefinedIcosahedralNodes initialization.""" - node_builder = TriIcosahedralNodes(resolution, "test_nodes") + node_builder = TriNodes(resolution, "test_nodes") assert isinstance(node_builder, BaseNodeBuilder) - assert isinstance(node_builder, TriIcosahedralNodes) + assert isinstance(node_builder, TriNodes) def test_get_coordinates(): """Test get_coordinates method.""" - node_builder = TriIcosahedralNodes(2, "test_nodes") + node_builder = TriNodes(2, "test_nodes") coords = node_builder.get_coordinates() assert isinstance(coords, torch.Tensor) assert coords.shape == (162, 2) @@ -25,7 +25,7 @@ def test_get_coordinates(): def test_update_graph(): """Test update_graph method.""" - node_builder = TriIcosahedralNodes(1, "test_nodes") + node_builder = TriNodes(1, "test_nodes") graph = HeteroData() graph = node_builder.update_graph(graph, {}) assert "resolutions" in graph["test_nodes"] From d190758d04471d7f41548d99741861a9ee46487d Mon Sep 17 00:00:00 2001 From: theissenhelen Date: Thu, 11 Jul 2024 10:57:12 +0000 Subject: [PATCH 108/156] refactor: LimitedArea prefix --- src/anemoi/graphs/nodes/__init__.py | 12 ++++++------ src/anemoi/graphs/nodes/builder.py | 6 +++--- 2 files changed, 9 insertions(+), 9 deletions(-) diff --git a/src/anemoi/graphs/nodes/__init__.py b/src/anemoi/graphs/nodes/__init__.py index b0d5201..0d1263c 100644 --- a/src/anemoi/graphs/nodes/__init__.py +++ b/src/anemoi/graphs/nodes/__init__.py @@ -1,7 +1,7 @@ -from .builder import AreaHexRefinedIcosahedralNodes -from .builder import AreaNPZFileNodes -from .builder import AreaTriRefinedIcosahedralNodes from .builder import HexNodes +from .builder import LimitedAreaHexNodes +from .builder import LimitedAreaNPZFileNodes +from .builder import LimitedAreaTriNodes from .builder import LimitedAreaZarrDatasetNodes from .builder import NPZFileNodes from .builder import TriNodes @@ -13,7 +13,7 @@ "TriNodes", "HexNodes", "LimitedAreaZarrDatasetNodes", - "AreaNPZFileNodes", - "AreaTriRefinedIcosahedralNodes", - "AreaHexRefinedIcosahedralNodes", + "LimitedAreaNPZFileNodes", + "LimitedAreaTriNodes", + "LimitedAreaHexNodes", ] diff --git a/src/anemoi/graphs/nodes/builder.py b/src/anemoi/graphs/nodes/builder.py index 6f82cf1..0529005 100644 --- a/src/anemoi/graphs/nodes/builder.py +++ b/src/anemoi/graphs/nodes/builder.py @@ -223,7 +223,7 @@ def get_coordinates(self) -> torch.Tensor: return coords -class AreaNPZFileNodes(NPZFileNodes): +class LimitedAreaNPZFileNodes(NPZFileNodes): """Processor mesh based on an NPZ defined grids using an area of interest.""" def __init__( @@ -311,7 +311,7 @@ def create_nodes(self) -> np.ndarray: return create_hexagonal_nodes(self.resolutions) -class AreaTriRefinedIcosahedralNodes(TriNodes): +class LimitedAreaTriNodes(TriNodes): """Class to build icosahedral nodes with a limited area of interest.""" def __init__( @@ -332,7 +332,7 @@ def register_nodes(self, graph: HeteroData) -> None: return super().register_nodes(graph) -class AreaHexRefinedIcosahedralNodes(HexNodes): +class LimitedAreaHexNodes(HexNodes): """Class to build icosahedral nodes with a limited area of interest.""" def __init__( From dca3729492881583d7df8c1b32b9c636d2088051 Mon Sep 17 00:00:00 2001 From: theissenhelen Date: Thu, 11 Jul 2024 15:49:52 +0000 Subject: [PATCH 109/156] feat: add aoi_mask_builder to edge builder --- src/anemoi/graphs/edges/builder.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/anemoi/graphs/edges/builder.py b/src/anemoi/graphs/edges/builder.py index 053f10f..6281485 100644 --- a/src/anemoi/graphs/edges/builder.py +++ b/src/anemoi/graphs/edges/builder.py @@ -306,6 +306,7 @@ def get_adjacency_matrix(self, source_nodes: NodeStorage, target_nodes: NodeStor source_nodes["nx_graph"], resolutions=source_nodes["resolutions"], xhops=self.xhops, + aoi_mask_builder=None if "aoi_mask_builder" not in source_nodes else source_nodes["aoi_mask_builder"], ) # HeteroData refuses to accept None adjmat = nx.to_scipy_sparse_array( From 9d015a38a3a5491330dbbe06269cf7cd75b90f70 Mon Sep 17 00:00:00 2001 From: Mario Santa Cruz Date: Tue, 6 Aug 2024 09:43:52 +0000 Subject: [PATCH 110/156] docs & default values --- src/anemoi/graphs/nodes/masks.py | 36 ++++++++++++++++++++++++++------ 1 file changed, 30 insertions(+), 6 deletions(-) diff --git a/src/anemoi/graphs/nodes/masks.py b/src/anemoi/graphs/nodes/masks.py index 786796f..48a3283 100644 --- a/src/anemoi/graphs/nodes/masks.py +++ b/src/anemoi/graphs/nodes/masks.py @@ -6,9 +6,28 @@ class KNNAreaMaskBuilder: - """Class to build a mask based on distance to masked reference nodes using KNN.""" - - def __init__(self, reference_node_name: str, margin_radius_km: float, mask_attr_name: str): + """Class to build a mask based on distance to masked reference nodes using KNN. + + Attributes + ---------- + nearest_neighbour : NearestNeighbors + Nearest neighbour object to compute the KNN. + margin_radius_km : float + Maximum distance to the reference nodes to consider a node as valid, in kilometers. Defaults to 100 km. + reference_node_name : str + Name of the reference nodes in the graph to consider for the Area Mask. + mask_attr_name : str + Name of a node to attribute to mask the reference nodes, if desired. Defaults to consider all reference nodes. + + Methods + ------- + fit(graph: HeteroData) + Fit the KNN model to the reference nodes. + get_mask(coords_rad: np.ndarray) -> np.ndarray + Get the mask for the nodes based on the distance to the reference nodes. + """ + + def __init__(self, reference_node_name: str, margin_radius_km: float = 100, mask_attr_name: str = None): self.nearest_neighbour = NearestNeighbors(metric="haversine", n_jobs=4) self.margin_radius_km = margin_radius_km @@ -16,11 +35,16 @@ def __init__(self, reference_node_name: str, margin_radius_km: float, mask_attr_ self.mask_attr_name = mask_attr_name def fit(self, graph: HeteroData): + """Fit the KNN model to the nodes of interest.""" coords_rad = graph[self.reference_node_name].x.numpy() - mask = graph[self.reference_node_name][self.mask_attr_name].squeeze() - self.nearest_neighbour.fit(coords_rad[mask, :]) + if self.mask_attr_name is not None: + mask = graph[self.reference_node_name][self.mask_attr_name].squeeze() + coords_rad = coords_rad[mask] + + self.nearest_neighbour.fit(coords_rad) - def get_mask(self, coords_rad: np.ndarray): + def get_mask(self, coords_rad: np.ndarray) -> np.ndarray: + """Compute a mask based on the distance to the reference nodes.""" neigh_dists, _ = self.nearest_neighbour.kneighbors(coords_rad, n_neighbors=1) mask = neigh_dists[:, 0] * EARTH_RADIUS <= self.margin_radius_km From 41bd3f51b68464516280cf42f77527f18df70d8f Mon Sep 17 00:00:00 2001 From: Mario Santa Cruz Date: Tue, 6 Aug 2024 09:47:47 +0000 Subject: [PATCH 111/156] create LimitedAreaHEALPixNodes --- src/anemoi/graphs/nodes/builder.py | 109 +++++++++++++++++++++++------ 1 file changed, 86 insertions(+), 23 deletions(-) diff --git a/src/anemoi/graphs/nodes/builder.py b/src/anemoi/graphs/nodes/builder.py index ae070c0..9bb7031 100644 --- a/src/anemoi/graphs/nodes/builder.py +++ b/src/anemoi/graphs/nodes/builder.py @@ -28,6 +28,8 @@ class BaseNodeBuilder(ABC): ---------- name : str name of the nodes, key for the nodes in the HeteroData graph object. + aoi_mask_builder : KNNAreaMaskBuilder + The area of interest mask builder, if any. Defaults to None. """ def __init__(self, name: str) -> None: @@ -67,21 +69,21 @@ def register_attributes(self, graph: HeteroData, config: Optional[DotDict] = Non return graph @abstractmethod - def get_coordinates(self) -> np.ndarray: ... + def get_coordinates(self) -> torch.Tensor: ... - def reshape_coords(self, latitudes: np.ndarray, longitudes: np.ndarray) -> np.ndarray: + def reshape_coords(self, latitudes: np.ndarray, longitudes: np.ndarray) -> torch.Tensor: """Reshape latitude and longitude coordinates. Parameters ---------- - latitudes : np.ndarray of shape (N, ) + latitudes : np.ndarray of shape (num_nodes, ) Latitude coordinates, in degrees. - longitudes : np.ndarray of shape (N, ) + longitudes : np.ndarray of shape (num_nodes, ) Longitude coordinates, in degrees. Returns ------- - torch.Tensor of shape (N, 2) + torch.Tensor of shape (num_nodes, 2) A 2D tensor with the coordinates, in radians. """ coords = np.stack([latitudes, longitudes], axis=-1).reshape((-1, 2)) @@ -143,8 +145,8 @@ def get_coordinates(self) -> torch.Tensor: Returns ------- - torch.Tensor of shape (N, 2) - Coordinates of the nodes. + torch.Tensor of shape (num_nodes, 2) + A 2D tensor with the coordinates, in radians. """ return self.reshape_coords(self.dataset.latitudes, self.dataset.longitudes) @@ -216,15 +218,15 @@ def get_coordinates(self) -> torch.Tensor: Returns ------- - torch.Tensor of shape (N, 2) - Coordinates of the nodes. + torch.Tensor of shape (num_nodes, 2) + A 2D tensor with the coordinates, in radians. """ coords = self.reshape_coords(self.grid_definition["latitudes"], self.grid_definition["longitudes"]) return coords class LimitedAreaNPZFileNodes(NPZFileNodes): - """Processor mesh based on an NPZ defined grids using an area of interest.""" + """Nodes from NPZ defined grids using an area of interest.""" def __init__( self, @@ -251,7 +253,7 @@ def get_coordinates(self) -> np.ndarray: "Limiting the processor mesh to a radius of %.2f km from the output mesh.", self.aoi_mask_builder.margin_radius_km, ) - aoi_mask = self.aoi_mask_builder.get_mask(np.deg2rad(coords)) + aoi_mask = self.aoi_mask_builder.get_mask(coords) LOGGER.info("Dropping %d nodes from the processor mesh.", len(aoi_mask) - aoi_mask.sum()) coords = coords[aoi_mask] @@ -260,11 +262,9 @@ def get_coordinates(self) -> np.ndarray: class IcosahedralNodes(BaseNodeBuilder, ABC): - """Processor mesh based on a triangular mesh. + """Nodes based on iterative refinements of an icosahedron. - It is based on the icosahedral mesh, which is a mesh of triangles that covers the sphere. - - Parameters + Attributes ---------- resolution : list[int] | int Refinement level of the mesh. @@ -283,6 +283,13 @@ def __init__( super().__init__(name) def get_coordinates(self) -> torch.Tensor: + """Get the coordinates of the nodes. + + Returns + ------- + torch.Tensor of shape (num_nodes, 2) + A 2D tensor with the coordinates, in radians. + """ self.nx_graph, coords_rad, self.node_ordering = self.create_nodes() return torch.tensor(coords_rad[self.node_ordering], dtype=torch.float32) @@ -298,21 +305,35 @@ def register_attributes(self, graph: HeteroData, config: DotDict) -> HeteroData: class TriNodes(IcosahedralNodes): - """It depends on the trimesh Python library.""" + """Nodes based on iterative refinements of an icosahedron. + + It depends on the trimesh Python library. + """ def create_nodes(self) -> np.ndarray: return create_icosahedral_nodes(resolutions=self.resolutions) class HexNodes(IcosahedralNodes): - """It depends on the h3 Python library.""" + """Nodes based on iterative refinements of an icosahedron. + + It depends on the h3 Python library. + """ def create_nodes(self) -> np.ndarray: return create_hexagonal_nodes(self.resolutions) class LimitedAreaTriNodes(TriNodes): - """Class to build icosahedral nodes with a limited area of interest.""" + """Nodes based on iterative refinements of an icosahedron using an area of interest. + + It depends on the trimesh Python library. + + Parameters + ---------- + aoi_mask_builder: KNNAreaMaskBuilder + The area of interest mask builder. + """ def __init__( self, @@ -333,7 +354,15 @@ def register_nodes(self, graph: HeteroData) -> None: class LimitedAreaHexNodes(HexNodes): - """Class to build icosahedral nodes with a limited area of interest.""" + """Nodes based on iterative refinements of an icosahedron using an area of interest. + + It depends on the h3 Python library. + + Parameters + ---------- + aoi_mask_builder: KNNAreaMaskBuilder + The area of interest mask builder. + """ def __init__( self, @@ -362,8 +391,6 @@ class HEALPixNodes(BaseNodeBuilder): ---------- resolution : int The resolution of the grid. - name : str - The name of the nodes. Methods ------- @@ -390,8 +417,8 @@ def get_coordinates(self) -> torch.Tensor: Returns ------- - torch.Tensor of shape (N, 2) - Coordinates of the nodes. + torch.Tensor of shape (num_nodes, 2) + Coordinates of the nodes, in radians. """ import healpy as hp @@ -402,3 +429,39 @@ def get_coordinates(self) -> torch.Tensor: hpxlon, hpxlat = hp.pix2ang(2**self.resolution, range(npix), nest=True, lonlat=True) return self.reshape_coords(hpxlat, hpxlon) + + +class LimitedAreaHEALPixNodes(HEALPixNodes): + """Nodes from HEALPix grid using an area of interest.""" + + def __init__( + self, + resolution: str, + name: str, + reference_node_name: str, + mask_attr_name: str, + margin_radius_km: float = 100.0, + ) -> None: + + self.aoi_mask_builder = KNNAreaMaskBuilder(reference_node_name, margin_radius_km, mask_attr_name) + + super().__init__(resolution, name) + + def register_nodes(self, graph: HeteroData) -> None: + self.aoi_mask_builder.fit(graph) + return super().register_nodes(graph) + + def get_coordinates(self) -> np.ndarray: + coords = super().get_coordinates() + + LOGGER.info( + 'Limiting the "%s" nodes to a radius of %.2f km from the nodes of interest.', + self.name, + self.aoi_mask_builder.margin_radius_km, + ) + aoi_mask = self.aoi_mask_builder.get_mask(coords) + + LOGGER.info('Masking out %d nodes from "%s".', len(aoi_mask) - aoi_mask.sum(), self.name) + coords = coords[aoi_mask] + + return coords From 030168cc49f8935cf525d3e1b759fd094e7172f0 Mon Sep 17 00:00:00 2001 From: Mario Santa Cruz Date: Tue, 6 Aug 2024 10:03:17 +0000 Subject: [PATCH 112/156] fix: import HEALPixNodes --- src/anemoi/graphs/nodes/__init__.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/src/anemoi/graphs/nodes/__init__.py b/src/anemoi/graphs/nodes/__init__.py index 0d1263c..2214d37 100644 --- a/src/anemoi/graphs/nodes/__init__.py +++ b/src/anemoi/graphs/nodes/__init__.py @@ -1,4 +1,6 @@ +from .builder import HEALPixNodes from .builder import HexNodes +from .builder import LimitedAreaHEALPixNodes from .builder import LimitedAreaHexNodes from .builder import LimitedAreaNPZFileNodes from .builder import LimitedAreaTriNodes @@ -12,6 +14,8 @@ "NPZFileNodes", "TriNodes", "HexNodes", + "HEALPixNodes", + "LimitedAreaHEALPixNodes", "LimitedAreaZarrDatasetNodes", "LimitedAreaNPZFileNodes", "LimitedAreaTriNodes", From 9493933d25e4358736de91a35a719dbfd508f925 Mon Sep 17 00:00:00 2001 From: Mario Santa Cruz Date: Tue, 6 Aug 2024 10:37:45 +0000 Subject: [PATCH 113/156] fix: avoid runtimeError when deleting a key --- src/anemoi/graphs/create.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/anemoi/graphs/create.py b/src/anemoi/graphs/create.py index 6a7eff8..cf68e40 100644 --- a/src/anemoi/graphs/create.py +++ b/src/anemoi/graphs/create.py @@ -59,7 +59,7 @@ def clean(self, graph: HeteroData) -> HeteroData: cleaned graph """ for type_name in chain(graph.node_types, graph.edge_types): - for attr_name in graph[type_name].keys(): + for attr_name in list(graph[type_name].keys()): if attr_name.startswith("_"): del graph[type_name][attr_name] From 986efe757ec089d98d6938cd1084d0eb3ae5a5b4 Mon Sep 17 00:00:00 2001 From: Mario Santa Cruz Date: Tue, 6 Aug 2024 10:38:28 +0000 Subject: [PATCH 114/156] fix: set config arg to pathlib.Path --- src/anemoi/graphs/commands/create.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/anemoi/graphs/commands/create.py b/src/anemoi/graphs/commands/create.py index cbf28b2..73d1125 100644 --- a/src/anemoi/graphs/commands/create.py +++ b/src/anemoi/graphs/commands/create.py @@ -18,7 +18,7 @@ def add_arguments(self, command_parser): help="Overwrite existing files. This will delete the target graph if it already exists.", ) command_parser.add_argument( - "config", help="Configuration yaml file path defining the recipe to create the graph." + "config", type=Path, help="Configuration yaml file path defining the recipe to create the graph." ) command_parser.add_argument("save_path", type=Path, help="Path to store the created graph.") From c9c736dc6c4d3a874a5e673aa7a1125ccdcdb89f Mon Sep 17 00:00:00 2001 From: Mario Santa Cruz Date: Tue, 6 Aug 2024 11:20:12 +0000 Subject: [PATCH 115/156] more logging --- src/anemoi/graphs/nodes/masks.py | 14 +++++++++++++- 1 file changed, 13 insertions(+), 1 deletion(-) diff --git a/src/anemoi/graphs/nodes/masks.py b/src/anemoi/graphs/nodes/masks.py index 48a3283..e1a4128 100644 --- a/src/anemoi/graphs/nodes/masks.py +++ b/src/anemoi/graphs/nodes/masks.py @@ -1,9 +1,13 @@ +import logging + import numpy as np from sklearn.neighbors import NearestNeighbors from torch_geometric.data import HeteroData from anemoi.graphs import EARTH_RADIUS +LOGGER = logging.getLogger(__name__) + class KNNAreaMaskBuilder: """Class to build a mask based on distance to masked reference nodes using KNN. @@ -36,11 +40,19 @@ def __init__(self, reference_node_name: str, margin_radius_km: float = 100, mask def fit(self, graph: HeteroData): """Fit the KNN model to the nodes of interest.""" + reference_mask_str = self.reference_node_name coords_rad = graph[self.reference_node_name].x.numpy() if self.mask_attr_name is not None: mask = graph[self.reference_node_name][self.mask_attr_name].squeeze() coords_rad = coords_rad[mask] - + reference_mask_str += f" ({self.mask_attr_name})" + + LOGGER.info( + 'Fitting %s with %d reference nodes from "%s".', + self.__class__.__name__, + len(coords_rad), + reference_mask_str, + ) self.nearest_neighbour.fit(coords_rad) def get_mask(self, coords_rad: np.ndarray) -> np.ndarray: From 6e571b30937db11f016a3eef509a58d423482442 Mon Sep 17 00:00:00 2001 From: Mario Santa Cruz Date: Tue, 6 Aug 2024 11:22:01 +0000 Subject: [PATCH 116/156] create LimitedAreaIcoshaderalNodes --- src/anemoi/graphs/nodes/builder.py | 55 ++++++++++++++++-------------- 1 file changed, 29 insertions(+), 26 deletions(-) diff --git a/src/anemoi/graphs/nodes/builder.py b/src/anemoi/graphs/nodes/builder.py index 9bb7031..0769df9 100644 --- a/src/anemoi/graphs/nodes/builder.py +++ b/src/anemoi/graphs/nodes/builder.py @@ -304,6 +304,29 @@ def register_attributes(self, graph: HeteroData, config: DotDict) -> HeteroData: return super().register_attributes(graph, config) +class LimitedAreaIcosahedralNodes(IcosahedralNodes): + """Nodes based on iterative refinements of an icosahedron using an area of interest. + + Attributes + ---------- + aoi_mask_builder : KNNAreaMaskBuilder + The area of interest mask builder. + """ + + def __init__( + self, + resolution: int | list[int], + name: str, + reference_node_name: str, + mask_attr_name: str, + margin_radius_km: float = 100.0, + ) -> None: + + super().__init__(resolution, name) + + self.aoi_mask_builder = KNNAreaMaskBuilder(reference_node_name, margin_radius_km, mask_attr_name) + + class TriNodes(IcosahedralNodes): """Nodes based on iterative refinements of an icosahedron. @@ -324,7 +347,7 @@ def create_nodes(self) -> np.ndarray: return create_hexagonal_nodes(self.resolutions) -class LimitedAreaTriNodes(TriNodes): +class LimitedAreaTriNodes(LimitedAreaIcosahedralNodes): """Nodes based on iterative refinements of an icosahedron using an area of interest. It depends on the trimesh Python library. @@ -335,25 +358,15 @@ class LimitedAreaTriNodes(TriNodes): The area of interest mask builder. """ - def __init__( - self, - resolution: int | list[int], - name: str, - reference_node_name: str, - mask_attr_name: str, - margin_radius_km: float = 100.0, - ) -> None: - - super().__init__(resolution, name) - - self.aoi_mask_builder = KNNAreaMaskBuilder(reference_node_name, margin_radius_km, mask_attr_name) + def create_nodes(self) -> np.ndarray: + return create_icosahedral_nodes(resolutions=self.resolutions, aoi_mask_builder=self.aoi_mask_builder) def register_nodes(self, graph: HeteroData) -> None: self.aoi_mask_builder.fit(graph) return super().register_nodes(graph) -class LimitedAreaHexNodes(HexNodes): +class LimitedAreaHexNodes(LimitedAreaIcosahedralNodes): """Nodes based on iterative refinements of an icosahedron using an area of interest. It depends on the h3 Python library. @@ -364,18 +377,8 @@ class LimitedAreaHexNodes(HexNodes): The area of interest mask builder. """ - def __init__( - self, - resolution: int | list[int], - name: str, - reference_node_name: str, - mask_attr_name: str, - margin_radius_km: float = 100.0, - ) -> None: - - super().__init__(resolution, name) - - self.aoi_mask_builder = KNNAreaMaskBuilder(reference_node_name, margin_radius_km, mask_attr_name) + def create_nodes(self) -> np.ndarray: + return create_hexagonal_nodes(self.resolutions, aoi_mask_builder=self.aoi_mask_builder) def register_nodes(self, graph: HeteroData) -> None: self.aoi_mask_builder.fit(graph) From 81a3419d444620cce49b179703ae109e4256f396 Mon Sep 17 00:00:00 2001 From: Mario Santa Cruz Date: Tue, 6 Aug 2024 15:40:37 +0000 Subject: [PATCH 117/156] refactor LimiteAreaIcosahedralNodes class --- src/anemoi/graphs/nodes/builder.py | 12 ++++-------- 1 file changed, 4 insertions(+), 8 deletions(-) diff --git a/src/anemoi/graphs/nodes/builder.py b/src/anemoi/graphs/nodes/builder.py index 0769df9..5fbd64a 100644 --- a/src/anemoi/graphs/nodes/builder.py +++ b/src/anemoi/graphs/nodes/builder.py @@ -326,6 +326,10 @@ def __init__( self.aoi_mask_builder = KNNAreaMaskBuilder(reference_node_name, margin_radius_km, mask_attr_name) + def register_nodes(self, graph: HeteroData) -> None: + self.aoi_mask_builder.fit(graph) + return super().register_nodes(graph) + class TriNodes(IcosahedralNodes): """Nodes based on iterative refinements of an icosahedron. @@ -361,10 +365,6 @@ class LimitedAreaTriNodes(LimitedAreaIcosahedralNodes): def create_nodes(self) -> np.ndarray: return create_icosahedral_nodes(resolutions=self.resolutions, aoi_mask_builder=self.aoi_mask_builder) - def register_nodes(self, graph: HeteroData) -> None: - self.aoi_mask_builder.fit(graph) - return super().register_nodes(graph) - class LimitedAreaHexNodes(LimitedAreaIcosahedralNodes): """Nodes based on iterative refinements of an icosahedron using an area of interest. @@ -380,10 +380,6 @@ class LimitedAreaHexNodes(LimitedAreaIcosahedralNodes): def create_nodes(self) -> np.ndarray: return create_hexagonal_nodes(self.resolutions, aoi_mask_builder=self.aoi_mask_builder) - def register_nodes(self, graph: HeteroData) -> None: - self.aoi_mask_builder.fit(graph) - return super().register_nodes(graph) - class HEALPixNodes(BaseNodeBuilder): """Nodes from HEALPix grid. From 8460e06430c14499d61db67d033967484c79f4fd Mon Sep 17 00:00:00 2001 From: Mario Santa Cruz Date: Tue, 6 Aug 2024 16:05:39 +0000 Subject: [PATCH 118/156] types & docstrings --- src/anemoi/graphs/generate/icosahedral.py | 4 ++-- src/anemoi/graphs/nodes/builder.py | 12 +++++++----- 2 files changed, 9 insertions(+), 7 deletions(-) diff --git a/src/anemoi/graphs/generate/icosahedral.py b/src/anemoi/graphs/generate/icosahedral.py index 39e99ae..dd0c914 100644 --- a/src/anemoi/graphs/generate/icosahedral.py +++ b/src/anemoi/graphs/generate/icosahedral.py @@ -52,8 +52,8 @@ def create_icosahedral_nodes( return nx_graph, coords_rad, list(node_ordering) -def create_icosahedral_nx_graph_from_coords(coords_rad: np.ndarray, node_ordering: list[int]): - +def create_icosahedral_nx_graph_from_coords(coords_rad: np.ndarray, node_ordering: list[int]) -> nx.DiGraph: + """Creates the networkx graph from the coordinates and the node ordering.""" graph = nx.DiGraph() for ii, coords in enumerate(coords_rad[node_ordering]): node_id = node_ordering[ii] diff --git a/src/anemoi/graphs/nodes/builder.py b/src/anemoi/graphs/nodes/builder.py index 5fbd64a..080b61e 100644 --- a/src/anemoi/graphs/nodes/builder.py +++ b/src/anemoi/graphs/nodes/builder.py @@ -3,8 +3,10 @@ from abc import abstractmethod from pathlib import Path from typing import Optional +from typing import Tuple from typing import Union +import networkx as nx import numpy as np import torch from anemoi.datasets import open_dataset @@ -294,7 +296,7 @@ def get_coordinates(self) -> torch.Tensor: return torch.tensor(coords_rad[self.node_ordering], dtype=torch.float32) @abstractmethod - def create_nodes(self) -> np.ndarray: ... + def create_nodes(self) -> Tuple[nx.DiGraph, np.ndarray, list[int]]: ... def register_attributes(self, graph: HeteroData, config: DotDict) -> HeteroData: graph[self.name]["_resolutions"] = self.resolutions @@ -337,7 +339,7 @@ class TriNodes(IcosahedralNodes): It depends on the trimesh Python library. """ - def create_nodes(self) -> np.ndarray: + def create_nodes(self) -> Tuple[nx.Graph, np.ndarray, list[int]]: return create_icosahedral_nodes(resolutions=self.resolutions) @@ -347,7 +349,7 @@ class HexNodes(IcosahedralNodes): It depends on the h3 Python library. """ - def create_nodes(self) -> np.ndarray: + def create_nodes(self) -> Tuple[nx.Graph, np.ndarray, list[int]]: return create_hexagonal_nodes(self.resolutions) @@ -362,7 +364,7 @@ class LimitedAreaTriNodes(LimitedAreaIcosahedralNodes): The area of interest mask builder. """ - def create_nodes(self) -> np.ndarray: + def create_nodes(self) -> Tuple[nx.Graph, np.ndarray, list[int]]: return create_icosahedral_nodes(resolutions=self.resolutions, aoi_mask_builder=self.aoi_mask_builder) @@ -377,7 +379,7 @@ class LimitedAreaHexNodes(LimitedAreaIcosahedralNodes): The area of interest mask builder. """ - def create_nodes(self) -> np.ndarray: + def create_nodes(self) -> Tuple[nx.Graph, np.ndarray, list[int]]: return create_hexagonal_nodes(self.resolutions, aoi_mask_builder=self.aoi_mask_builder) From c35c4ffad103db1f8d987b51616b329ed39fd9ec Mon Sep 17 00:00:00 2001 From: Mario Santa Cruz Date: Wed, 7 Aug 2024 09:43:55 +0000 Subject: [PATCH 119/156] fix(lam): icosahedral nodes in lam --- src/anemoi/graphs/generate/icosahedral.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/anemoi/graphs/generate/icosahedral.py b/src/anemoi/graphs/generate/icosahedral.py index dd0c914..5303efc 100644 --- a/src/anemoi/graphs/generate/icosahedral.py +++ b/src/anemoi/graphs/generate/icosahedral.py @@ -45,7 +45,7 @@ def create_icosahedral_nodes( if aoi_mask_builder is not None: aoi_mask = aoi_mask_builder.get_mask(coords_rad) - node_ordering = node_ordering[aoi_mask] + node_ordering = node_ordering[aoi_mask[node_ordering]] nx_graph = create_icosahedral_nx_graph_from_coords(coords_rad, node_ordering) From 60118ddb7f891219df6cefb901d5eae3abfdb412 Mon Sep 17 00:00:00 2001 From: Mario Santa Cruz Date: Wed, 7 Aug 2024 11:22:30 +0000 Subject: [PATCH 120/156] fix: style --- src/anemoi/graphs/generate/icosahedral.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/src/anemoi/graphs/generate/icosahedral.py b/src/anemoi/graphs/generate/icosahedral.py index 5303efc..ac2c9d0 100644 --- a/src/anemoi/graphs/generate/icosahedral.py +++ b/src/anemoi/graphs/generate/icosahedral.py @@ -31,11 +31,11 @@ def create_icosahedral_nodes( Returns ------- graph : networkx.Graph - The specified graph (nodes & edges). + The specified graph (only nodes) sorted by latitude and longitude. coords_rad : np.ndarray The node coordinates (not ordered) in radians. node_ordering : list[int] - Order of the nodes in the graph to be sorted by latitude and longitude. + Order of the node coordinates to be sorted by latitude and longitude. """ sphere = trimesh.creation.icosphere(subdivisions=resolutions[-1], radius=1.0) @@ -47,16 +47,17 @@ def create_icosahedral_nodes( aoi_mask = aoi_mask_builder.get_mask(coords_rad) node_ordering = node_ordering[aoi_mask[node_ordering]] + # Creates the graph, with the nodes sorted by latitude and longitude. nx_graph = create_icosahedral_nx_graph_from_coords(coords_rad, node_ordering) return nx_graph, coords_rad, list(node_ordering) -def create_icosahedral_nx_graph_from_coords(coords_rad: np.ndarray, node_ordering: list[int]) -> nx.DiGraph: +def create_icosahedral_nx_graph_from_coords(coords_rad: np.ndarray, node_ordering: np.ndarray) -> nx.DiGraph: """Creates the networkx graph from the coordinates and the node ordering.""" graph = nx.DiGraph() - for ii, coords in enumerate(coords_rad[node_ordering]): - node_id = node_ordering[ii] + for i, coords in enumerate(coords_rad[node_ordering]): + node_id = node_ordering[i] graph.add_node(node_id, hcoords_rad=coords) assert list(graph.nodes.keys()) == list(node_ordering), "Nodes are not correctly added to the graph." From 6c0585efd696cca692c6df3ee62b76895df6f5ed Mon Sep 17 00:00:00 2001 From: Mario Santa Cruz Date: Wed, 7 Aug 2024 15:29:13 +0000 Subject: [PATCH 121/156] feat: icosahedral & hexagonal lam multiscale edges for lam --- src/anemoi/graphs/edges/builder.py | 57 ++++++------ src/anemoi/graphs/generate/hexagonal.py | 100 +++++++++------------- src/anemoi/graphs/generate/icosahedral.py | 5 +- src/anemoi/graphs/generate/utils.py | 21 +++++ src/anemoi/graphs/nodes/builder.py | 2 +- 5 files changed, 92 insertions(+), 93 deletions(-) create mode 100644 src/anemoi/graphs/generate/utils.py diff --git a/src/anemoi/graphs/edges/builder.py b/src/anemoi/graphs/edges/builder.py index fdf1738..a36fd51 100644 --- a/src/anemoi/graphs/edges/builder.py +++ b/src/anemoi/graphs/edges/builder.py @@ -16,6 +16,8 @@ from anemoi.graphs.generate import hexagonal from anemoi.graphs.generate import icosahedral from anemoi.graphs.nodes.builder import HexNodes +from anemoi.graphs.nodes.builder import LimitedAreaHexNodes +from anemoi.graphs.nodes.builder import LimitedAreaTriNodes from anemoi.graphs.nodes.builder import TriNodes from anemoi.graphs.utils import get_grid_reference_distance @@ -268,61 +270,52 @@ def get_adjacency_matrix(self, source_nodes: NodeStorage, target_nodes: NodeStor class MultiScaleEdges(BaseEdgeBuilder): """Base class for multi-scale edges in the nodes of a graph.""" + VALID_NODES = [TriNodes, HexNodes, LimitedAreaTriNodes, LimitedAreaHexNodes] + def __init__(self, source_name: str, target_name: str, x_hops: int): super().__init__(source_name, target_name) assert source_name == target_name, f"{self.__class__.__name__} requires source and target nodes to be the same." assert isinstance(x_hops, int), "Number of x_hops must be an integer" assert x_hops > 0, "Number of x_hops must be positive" self.x_hops = x_hops + self.node_type = None - def adjacency_from_tri_nodes(self, source_nodes: NodeStorage): - source_nodes["_nx_graph"] = icosahedral.add_edges_to_nx_graph( - source_nodes["_nx_graph"], - resolutions=source_nodes["_resolutions"], + def add_edges_from_tri_nodes(self, nodes: NodeStorage) -> NodeStorage: + nodes["_nx_graph"] = icosahedral.add_edges_to_nx_graph( + nodes["_nx_graph"], + resolutions=nodes["_resolutions"], x_hops=self.x_hops, - ) # HeteroData refuses to accept None - - adjmat = nx.to_scipy_sparse_array( - source_nodes["_nx_graph"], nodelist=list(range(len(source_nodes["_nx_graph"]))), format="coo" + aoi_mask_builder=nodes.get("_aoi_mask_builder", None), ) - return adjmat - def adjacency_from_hex_nodes(self, source_nodes: NodeStorage): + return nodes - source_nodes["_nx_graph"] = hexagonal.add_edges_to_nx_graph( - source_nodes["_nx_graph"], - resolutions=source_nodes["_resolutions"], + def add_edges_from_hex_nodes(self, nodes: NodeStorage) -> NodeStorage: + nodes["_nx_graph"] = hexagonal.add_edges_to_nx_graph( + nodes["_nx_graph"], + resolutions=nodes["_resolutions"], x_hops=self.x_hops, ) - adjmat = nx.to_scipy_sparse_array(source_nodes["_nx_graph"], format="coo") - return adjmat + return nodes def get_adjacency_matrix(self, source_nodes: NodeStorage, target_nodes: NodeStorage): - if self.node_type == TriNodes.__name__: - adjmat = self.adjacency_from_tri_nodes(source_nodes) - elif self.node_type == HexNodes.__name__: - adjmat = self.adjacency_from_hex_nodes(source_nodes) + if self.node_type in [TriNodes.__name__, LimitedAreaTriNodes.__name__]: + source_nodes = self.add_edges_from_tri_nodes(source_nodes) + elif self.node_type in [HexNodes.__name__, LimitedAreaHexNodes.__name__]: + source_nodes = self.add_edges_from_hex_nodes(source_nodes) else: raise ValueError(f"Invalid node type {self.node_type}") - adjmat = self.post_process_adjmat(source_nodes, adjmat) - - return adjmat + adjmat = nx.to_scipy_sparse_array(source_nodes["_nx_graph"], format="coo") - def post_process_adjmat(self, nodes: NodeStorage, adjmat): - graph_sorted = {node_pos: i for i, node_pos in enumerate(nodes["_node_ordering"])} - sort_func = np.vectorize(graph_sorted.get) - adjmat.row = sort_func(adjmat.row) - adjmat.col = sort_func(adjmat.col) return adjmat def update_graph(self, graph: HeteroData, attrs_config: DotDict | None = None) -> HeteroData: - assert ( - graph[self.source_name].node_type == TriNodes.__name__ - or graph[self.source_name].node_type == HexNodes.__name__ - ), f"{self.__class__.__name__} requires {TriNodes.__name__} or {HexNodes.__name__}." - self.node_type = graph[self.source_name].node_type + valid_node_names = [n.__name__ for n in self.VALID_NODES] + assert ( + self.node_type in valid_node_names + ), f"{self.__class__.__name__} requires {','.join(valid_node_names)} nodes." return super().update_graph(graph, attrs_config) diff --git a/src/anemoi/graphs/generate/hexagonal.py b/src/anemoi/graphs/generate/hexagonal.py index 42db0a7..f056b52 100644 --- a/src/anemoi/graphs/generate/hexagonal.py +++ b/src/anemoi/graphs/generate/hexagonal.py @@ -3,12 +3,14 @@ import h3 import networkx as nx import numpy as np -from sklearn.metrics.pairwise import haversine_distances + +from anemoi.graphs.generate.utils import get_coordinates_ordering +from anemoi.graphs.nodes.masks import KNNAreaMaskBuilder def create_hexagonal_nodes( resolutions: list[int], - area: Optional[dict] = None, + aoi_mask_builder: Optional[KNNAreaMaskBuilder] = None, ) -> tuple[nx.Graph, np.ndarray, list[int]]: """Creates a global mesh from a refined icosahedro. @@ -19,70 +21,60 @@ def create_hexagonal_nodes( ---------- resolutions : list[int] Levels of mesh resolution to consider. - area : dict - A region, in GeoJSON data format, to be contained by all cells. Defaults to None, which computes the global - mesh. aoi_mask_builder : KNNAreaMaskBuilder, optional KNNAreaMaskBuilder with the cloud of points to limit the mesh area, by default None. Returns ------- graph : networkx.Graph - The specified graph (nodes & edges). + The specified graph (only nodes) sorted by latitude and longitude. coords_rad : np.ndarray The node coordinates (not ordered) in radians. node_ordering : list[int] - Order of the nodes in the graph to be sorted by latitude and longitude. + Order of the node coordinates to be sorted by latitude and longitude. """ - graph = nx.Graph() + nodes = get_nodes_at_resolution(max(resolutions)) - area_kwargs = {"area": area} + coords_rad = np.deg2rad(np.array([h3.h3_to_geo(node) for node in nodes])) - for resolution in resolutions: - graph = add_nodes_for_resolution(graph, resolution, **area_kwargs) + node_ordering = get_coordinates_ordering(coords_rad) - coords = np.deg2rad(np.array([h3.h3_to_geo(node) for node in graph.nodes])) + if aoi_mask_builder is not None: + aoi_mask = aoi_mask_builder.get_mask(coords_rad) + node_ordering = node_ordering[aoi_mask[node_ordering]] - # Sort nodes by latitude and longitude - node_ordering = np.lexsort(coords.T[::-1], axis=0) + graph = create_hexagonal_nx_graph_from_coords(nodes, node_ordering) - return graph, coords, list(node_ordering) + return graph, coords_rad, list(node_ordering) -def add_nodes_for_resolution( - graph: nx.Graph, - resolution: int, - **area_kwargs: Optional[dict], -) -> nx.Graph: +def create_hexagonal_nx_graph_from_coords(nodes: set[str], node_ordering: np.ndarray) -> nx.Graph: """Add all nodes at a specified refinement level to a graph. Parameters ---------- - graph : networkx.Graph - The graph to add the nodes. - resolution : int - The H3 refinement level. It can be an integer from 0 to 15. - area_kwargs: dict - Additional arguments to pass to the get_nodes_at_resolution function. + nodes : list[str] + The set of H3 indexes (nodes). + node_ordering: np.ndarray + Order of the node coordinates to be sorted by latitude and longitude. Returns ------- graph : networkx.Graph The graph with the added nodes. """ + graph = nx.Graph() - nodes = get_nodes_at_resolution(resolution, **area_kwargs) - - for idx in nodes: - graph.add_node(idx, hcoords_rad=np.deg2rad(h3.h3_to_geo(idx))) + for node_pos in node_ordering: + h3_idx = nodes[node_pos] + graph.add_node(h3_idx, hcoords_rad=np.deg2rad(h3.h3_to_geo(h3_idx))) return graph def get_nodes_at_resolution( resolution: int, - area: Optional[dict] = None, -) -> set[str]: +) -> list[str]: """Get nodes at a specified refinement level over the entire globe. If area is not None, it will return the nodes within the specified area @@ -91,28 +83,22 @@ def get_nodes_at_resolution( ---------- resolution : int The H3 refinement level. It can be an integer from 0 to 15. - area : dict - An area as GeoJSON dictionary specifying a polygon. Defaults to None. aoi_mask_builder : KNNAreaMaskBuilder, optional KNNAreaMaskBuilder computes nask to limit the mesh area, by default None. Returns ------- - nodes : set[str] - The set of H3 indexes at the specified resolution level. + nodes : list[str] + The list of H3 indexes at the specified resolution level. """ - nodes = h3.uncompact(h3.get_res0_indexes(), resolution) if area is None else h3.polyfill(area, resolution) - - # TODO: AOI not used in the current implementation. - - return nodes + return list(h3.uncompact(h3.get_res0_indexes(), resolution)) def add_edges_to_nx_graph( graph: nx.Graph, resolutions: list[int], x_hops: int = 1, - depth_children: int = 1, + depth_children: int = 0, ) -> nx.Graph: """Adds the edges to the graph. @@ -130,6 +116,8 @@ def add_edges_to_nx_graph( depth_children : int The number of resolution levels to consider for the connections of children. Defaults to 1, which includes connections up to the next resolution level. + aoi_mask_builder : KNNAreaMaskBuilder + NearestNeighbors with the cloud of points to limit the mesh area, by default None. Returns ------- @@ -138,11 +126,7 @@ def add_edges_to_nx_graph( """ graph = add_neighbour_edges(graph, resolutions, x_hops) - graph = add_edges_to_children( - graph, - resolutions, - depth_children, - ) + graph = add_edges_to_children(graph, resolutions, depth_children) return graph @@ -160,8 +144,8 @@ def add_neighbour_edges( for idx_neighbour in h3.k_ring(idx, k=x_hops) & set(nodes): graph = add_edge( graph, - h3.h3_to_center_child(idx, refinement_levels[-1]), - h3.h3_to_center_child(idx_neighbour, refinement_levels[-1]), + h3.h3_to_center_child(idx, max(refinement_levels)), + h3.h3_to_center_child(idx_neighbour, max(refinement_levels)), ) return graph @@ -191,8 +175,10 @@ def add_edges_to_children( """ if depth_children is None: depth_children = len(refinement_levels) + elif depth_children == 0: + return graph - for i_level, resolution_parent in enumerate(refinement_levels[0:-1]): + for i_level, resolution_parent in enumerate(list(sorted(refinement_levels))[0:-1]): parent_nodes = select_nodes_from_graph_at_resolution(graph, resolution_parent) for parent_idx in parent_nodes: @@ -208,9 +194,11 @@ def add_edges_to_children( return graph -def select_nodes_from_graph_at_resolution(graph: nx.Graph, resolution: int) -> list[int]: - parent_nodes = [node for node in graph.nodes if h3.h3_get_resolution(node) == resolution] - return parent_nodes +def select_nodes_from_graph_at_resolution(graph: nx.Graph, resolution: int) -> set[str]: + """Select nodes from a graph at a specified resolution level.""" + nodes_at_lower_resolution = [n for n in h3.compact(graph.nodes) if h3.h3_get_resolution(n) <= resolution] + nodes_at_resolution = h3.uncompact(nodes_at_lower_resolution, resolution) + return nodes_at_resolution def add_edge( @@ -240,10 +228,6 @@ def add_edge( return graph if source_node_h3_idx != target_node_h3_idx: - source_location = np.deg2rad(h3.h3_to_geo(source_node_h3_idx)) - target_location = np.deg2rad(h3.h3_to_geo(target_node_h3_idx)) - graph.add_edge( - source_node_h3_idx, target_node_h3_idx, weight=haversine_distances([source_location, target_location])[0][1] - ) + graph.add_edge(source_node_h3_idx, target_node_h3_idx) return graph diff --git a/src/anemoi/graphs/generate/icosahedral.py b/src/anemoi/graphs/generate/icosahedral.py index ac2c9d0..20fea30 100644 --- a/src/anemoi/graphs/generate/icosahedral.py +++ b/src/anemoi/graphs/generate/icosahedral.py @@ -9,6 +9,7 @@ from sklearn.neighbors import BallTree from anemoi.graphs.generate.transforms import cartesian_to_latlon_rad +from anemoi.graphs.generate.utils import get_coordinates_ordering from anemoi.graphs.nodes.masks import KNNAreaMaskBuilder logger = logging.getLogger(__name__) @@ -41,7 +42,7 @@ def create_icosahedral_nodes( coords_rad = cartesian_to_latlon_rad(sphere.vertices) - node_ordering = get_node_ordering(coords_rad) + node_ordering = get_coordinates_ordering(coords_rad) if aoi_mask_builder is not None: aoi_mask = aoi_mask_builder.get_mask(coords_rad) @@ -79,7 +80,7 @@ def add_edges_to_nx_graph( resolutions: list[int], x_hops: int = 1, aoi_mask_builder: Optional[KNNAreaMaskBuilder] = None, -) -> None: +) -> nx.DiGraph: """Adds the edges to the graph. This method adds multi-scale connections to the existing graph. The corresponfing nodes or vertices diff --git a/src/anemoi/graphs/generate/utils.py b/src/anemoi/graphs/generate/utils.py new file mode 100644 index 0000000..82b4960 --- /dev/null +++ b/src/anemoi/graphs/generate/utils.py @@ -0,0 +1,21 @@ +import numpy as np + + +def get_coordinates_ordering(coords: np.ndarray) -> np.ndarray: + """Sort node coordinates by latitude and longitude. + + Parameters + ---------- + coords : np.ndarray of shape (N, 2) + The node coordinates, with the latitude in the first column and the + longitude in the second column. + + Returns + ------- + np.ndarray + The order of the node coordinates to be sorted by latitude and longitude. + """ + index_latitude = np.argsort(coords[:, 1]) + index_longitude = np.argsort(coords[index_latitude][:, 0])[::-1] + node_ordering = np.arange(coords.shape[0])[index_latitude][index_longitude] + return node_ordering diff --git a/src/anemoi/graphs/nodes/builder.py b/src/anemoi/graphs/nodes/builder.py index 080b61e..82b3a74 100644 --- a/src/anemoi/graphs/nodes/builder.py +++ b/src/anemoi/graphs/nodes/builder.py @@ -350,7 +350,7 @@ class HexNodes(IcosahedralNodes): """ def create_nodes(self) -> Tuple[nx.Graph, np.ndarray, list[int]]: - return create_hexagonal_nodes(self.resolutions) + return create_hexagonal_nodes(resolutions=self.resolutions) class LimitedAreaTriNodes(LimitedAreaIcosahedralNodes): From be7265333d740097944ea45b1a23e7cfa9986ba0 Mon Sep 17 00:00:00 2001 From: Mario Santa Cruz Date: Thu, 8 Aug 2024 13:02:46 +0000 Subject: [PATCH 122/156] fixs(docs): typo --- src/anemoi/graphs/generate/hexagonal.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/anemoi/graphs/generate/hexagonal.py b/src/anemoi/graphs/generate/hexagonal.py index f056b52..2dd7105 100644 --- a/src/anemoi/graphs/generate/hexagonal.py +++ b/src/anemoi/graphs/generate/hexagonal.py @@ -12,7 +12,7 @@ def create_hexagonal_nodes( resolutions: list[int], aoi_mask_builder: Optional[KNNAreaMaskBuilder] = None, ) -> tuple[nx.Graph, np.ndarray, list[int]]: - """Creates a global mesh from a refined icosahedro. + """Creates a global mesh from a refined icosahedron. This method relies on the H3 python library, which covers the earth with hexagons (and 5 pentagons). At each refinement level, a hexagon cell (nodes) has 7 child cells (aperture 7). From dcbd4edc0eab03f93b9a48a04640070110caf52b Mon Sep 17 00:00:00 2001 From: Mario Santa Cruz Date: Fri, 9 Aug 2024 09:19:38 +0000 Subject: [PATCH 123/156] fix: remove redundant code --- src/anemoi/graphs/generate/icosahedral.py | 9 --------- src/anemoi/graphs/generate/utils.py | 1 + 2 files changed, 1 insertion(+), 9 deletions(-) diff --git a/src/anemoi/graphs/generate/icosahedral.py b/src/anemoi/graphs/generate/icosahedral.py index 20fea30..3110c5f 100644 --- a/src/anemoi/graphs/generate/icosahedral.py +++ b/src/anemoi/graphs/generate/icosahedral.py @@ -66,15 +66,6 @@ def create_icosahedral_nx_graph_from_coords(coords_rad: np.ndarray, node_orderin return graph -def get_node_ordering(coords_rad: np.ndarray) -> np.ndarray: - """Get the node ordering to sort the nodes by latitude and longitude.""" - # Get indices to sort points by lon & lat in radians. - index_latitude = np.argsort(coords_rad[:, 1]) - index_longitude = np.argsort(coords_rad[index_latitude][:, 0])[::-1] - node_ordering = np.arange(coords_rad.shape[0])[index_latitude][index_longitude] - return node_ordering - - def add_edges_to_nx_graph( graph: nx.DiGraph, resolutions: list[int], diff --git a/src/anemoi/graphs/generate/utils.py b/src/anemoi/graphs/generate/utils.py index 82b4960..df72a2a 100644 --- a/src/anemoi/graphs/generate/utils.py +++ b/src/anemoi/graphs/generate/utils.py @@ -15,6 +15,7 @@ def get_coordinates_ordering(coords: np.ndarray) -> np.ndarray: np.ndarray The order of the node coordinates to be sorted by latitude and longitude. """ + # Get indices to sort points by lon & lat in radians. index_latitude = np.argsort(coords[:, 1]) index_longitude = np.argsort(coords[index_latitude][:, 0])[::-1] node_ordering = np.arange(coords.shape[0])[index_latitude][index_longitude] From e2cb2b455a2cec81deba0b22574989f34a915d33 Mon Sep 17 00:00:00 2001 From: Mario Santa Cruz Date: Fri, 9 Aug 2024 11:16:04 +0000 Subject: [PATCH 124/156] refactor: remove edge attr computation during graph creation --- src/anemoi/graphs/generate/icosahedral.py | 16 +++------------- 1 file changed, 3 insertions(+), 13 deletions(-) diff --git a/src/anemoi/graphs/generate/icosahedral.py b/src/anemoi/graphs/generate/icosahedral.py index 3110c5f..49251f2 100644 --- a/src/anemoi/graphs/generate/icosahedral.py +++ b/src/anemoi/graphs/generate/icosahedral.py @@ -5,7 +5,6 @@ import networkx as nx import numpy as np import trimesh -from sklearn.metrics.pairwise import haversine_distances from sklearn.neighbors import BallTree from anemoi.graphs.generate.transforms import cartesian_to_latlon_rad @@ -100,7 +99,7 @@ def add_edges_to_nx_graph( node_neighbours = get_neighbours_within_hops(sphere, x_hops, valid_nodes=list(graph.nodes)) for idx_node, idx_neighbours in node_neighbours.items(): - add_neigbours_edges(graph, vertices_rad, idx_node, idx_neighbours) + add_neigbours_edges(graph, idx_node, idx_neighbours) tree = BallTree(vertices_rad, metric="haversine") @@ -123,9 +122,7 @@ def add_edges_to_nx_graph( _, vertex_mapping_index = tree.query(r_vertices_rad, k=1) for idx_node, idx_neighbours in node_neighbours.items(): - add_neigbours_edges( - graph, r_vertices_rad, idx_node, idx_neighbours, vertex_mapping_index=vertex_mapping_index - ) + add_neigbours_edges(graph, idx_node, idx_neighbours, vertex_mapping_index=vertex_mapping_index) return graph @@ -168,7 +165,6 @@ def get_neighbours_within_hops( def add_neigbours_edges( graph: nx.Graph, - vertices: np.ndarray, node_idx: int, neighbour_indices: Iterable[int], self_loops: bool = False, @@ -180,8 +176,6 @@ def add_neigbours_edges( ---------- graph : nx.Graph The graph. - vertices : np.ndarray - A 2D array of shape (num_vertices, 2) with the planar coordinates of the mesh, in radians. node_idx : int The node considered. neighbours : list[int] @@ -195,10 +189,6 @@ def add_neigbours_edges( if not self_loops and node_idx == neighbour_idx: # no self-loops continue - location_node = vertices[node_idx] - location_neighbour = vertices[neighbour_idx] - edge_length = haversine_distances([location_neighbour, location_node])[0][1] - if vertex_mapping_index is not None: # Use the same method to add edge in all spheres node_neighbour = vertex_mapping_index[neighbour_idx][0] @@ -208,4 +198,4 @@ def add_neigbours_edges( # add edge to the graph if node in graph and node_neighbour in graph: - graph.add_edge(node_neighbour, node, weight=edge_length) + graph.add_edge(node_neighbour, node) From 74533d4745d43ce0dbe99b033aac8ad23d34f7bf Mon Sep 17 00:00:00 2001 From: Mario Santa Cruz Date: Fri, 9 Aug 2024 11:25:30 +0000 Subject: [PATCH 125/156] refactor: split builder.py in several files --- src/anemoi/graphs/nodes/__init__.py | 20 +- src/anemoi/graphs/nodes/builder.py | 468 ------------------ src/anemoi/graphs/nodes/builders/base.py | 105 ++++ src/anemoi/graphs/nodes/builders/from_file.py | 159 ++++++ .../graphs/nodes/builders/from_healpix.py | 95 ++++ .../builders/from_refined_icosahedron.py | 138 ++++++ 6 files changed, 507 insertions(+), 478 deletions(-) delete mode 100644 src/anemoi/graphs/nodes/builder.py create mode 100644 src/anemoi/graphs/nodes/builders/base.py create mode 100644 src/anemoi/graphs/nodes/builders/from_file.py create mode 100644 src/anemoi/graphs/nodes/builders/from_healpix.py create mode 100644 src/anemoi/graphs/nodes/builders/from_refined_icosahedron.py diff --git a/src/anemoi/graphs/nodes/__init__.py b/src/anemoi/graphs/nodes/__init__.py index 2214d37..991c3c2 100644 --- a/src/anemoi/graphs/nodes/__init__.py +++ b/src/anemoi/graphs/nodes/__init__.py @@ -1,13 +1,13 @@ -from .builder import HEALPixNodes -from .builder import HexNodes -from .builder import LimitedAreaHEALPixNodes -from .builder import LimitedAreaHexNodes -from .builder import LimitedAreaNPZFileNodes -from .builder import LimitedAreaTriNodes -from .builder import LimitedAreaZarrDatasetNodes -from .builder import NPZFileNodes -from .builder import TriNodes -from .builder import ZarrDatasetNodes +from .builders.from_file import LimitedAreaNPZFileNodes +from .builders.from_file import LimitedAreaZarrDatasetNodes +from .builders.from_file import NPZFileNodes +from .builders.from_file import ZarrDatasetNodes +from .builders.from_healpix import HEALPixNodes +from .builders.from_healpix import LimitedAreaHEALPixNodes +from .builders.from_refined_icosahedron import HexNodes +from .builders.from_refined_icosahedron import LimitedAreaHexNodes +from .builders.from_refined_icosahedron import LimitedAreaTriNodes +from .builders.from_refined_icosahedron import TriNodes __all__ = [ "ZarrDatasetNodes", diff --git a/src/anemoi/graphs/nodes/builder.py b/src/anemoi/graphs/nodes/builder.py deleted file mode 100644 index 82b3a74..0000000 --- a/src/anemoi/graphs/nodes/builder.py +++ /dev/null @@ -1,468 +0,0 @@ -import logging -from abc import ABC -from abc import abstractmethod -from pathlib import Path -from typing import Optional -from typing import Tuple -from typing import Union - -import networkx as nx -import numpy as np -import torch -from anemoi.datasets import open_dataset -from anemoi.utils.config import DotDict -from hydra.utils import instantiate -from torch_geometric.data import HeteroData - -from anemoi.graphs.generate.hexagonal import create_hexagonal_nodes -from anemoi.graphs.generate.icosahedral import create_icosahedral_nodes -from anemoi.graphs.nodes.masks import KNNAreaMaskBuilder - -LOGGER = logging.getLogger(__name__) - - -class BaseNodeBuilder(ABC): - """Base class for node builders. - - The node coordinates are stored in the `x` attribute of the nodes and they are stored in radians. - - Attributes - ---------- - name : str - name of the nodes, key for the nodes in the HeteroData graph object. - aoi_mask_builder : KNNAreaMaskBuilder - The area of interest mask builder, if any. Defaults to None. - """ - - def __init__(self, name: str) -> None: - self.name = name - self.aoi_mask_builder = None - - def register_nodes(self, graph: HeteroData) -> None: - """Register nodes in the graph. - - Parameters - ---------- - graph : HeteroData - The graph to register the nodes. - """ - graph[self.name].x = self.get_coordinates() - graph[self.name].node_type = type(self).__name__ - return graph - - def register_attributes(self, graph: HeteroData, config: Optional[DotDict] = None) -> HeteroData: - """Register attributes in the nodes of the graph specified. - - Parameters - ---------- - graph : HeteroData - The graph to register the attributes. - config : DotDict - The configuration of the attributes. - - Returns - ------- - HeteroData - The graph with the registered attributes. - """ - for attr_name, attr_config in config.items(): - graph[self.name][attr_name] = instantiate(attr_config).compute(graph, self.name) - - return graph - - @abstractmethod - def get_coordinates(self) -> torch.Tensor: ... - - def reshape_coords(self, latitudes: np.ndarray, longitudes: np.ndarray) -> torch.Tensor: - """Reshape latitude and longitude coordinates. - - Parameters - ---------- - latitudes : np.ndarray of shape (num_nodes, ) - Latitude coordinates, in degrees. - longitudes : np.ndarray of shape (num_nodes, ) - Longitude coordinates, in degrees. - - Returns - ------- - torch.Tensor of shape (num_nodes, 2) - A 2D tensor with the coordinates, in radians. - """ - coords = np.stack([latitudes, longitudes], axis=-1).reshape((-1, 2)) - coords = np.deg2rad(coords) - return torch.tensor(coords, dtype=torch.float32) - - def update_graph(self, graph: HeteroData, attr_config: Optional[DotDict] = None) -> HeteroData: - """Update the graph with new nodes. - - Parameters - ---------- - graph : HeteroData - Input graph. - attr_config : DotDict - The configuration of the attributes. - - Returns - ------- - HeteroData - The graph with new nodes included. - """ - graph = self.register_nodes(graph) - - if attr_config is None: - return graph - - graph = self.register_attributes(graph, attr_config) - - return graph - - -class ZarrDatasetNodes(BaseNodeBuilder): - """Nodes from Zarr dataset. - - Attributes - ---------- - dataset : zarr.core.Array - The dataset. - - Methods - ------- - get_coordinates() - Get the lat-lon coordinates of the nodes. - register_nodes(graph, name) - Register the nodes in the graph. - register_attributes(graph, name, config) - Register the attributes in the nodes of the graph specified. - update_graph(graph, name, attr_config) - Update the graph with new nodes and attributes. - """ - - def __init__(self, dataset: DotDict, name: str) -> None: - LOGGER.info("Reading the dataset from %s.", dataset) - self.dataset = open_dataset(dataset) - super().__init__(name) - - def get_coordinates(self) -> torch.Tensor: - """Get the coordinates of the nodes. - - Returns - ------- - torch.Tensor of shape (num_nodes, 2) - A 2D tensor with the coordinates, in radians. - """ - return self.reshape_coords(self.dataset.latitudes, self.dataset.longitudes) - - -class LimitedAreaZarrDatasetNodes(ZarrDatasetNodes): - """Nodes from Zarr dataset.""" - - def __init__( - self, name: str, lam_dataset: str, forcing_dataset: str, thinning: int = 1, adjust: str = "all" - ) -> None: - dataset_config = { - "cutout": [{"dataset": lam_dataset, "thinning": thinning}, {"dataset": forcing_dataset}], - "adjust": adjust, - } - super().__init__(dataset_config, name) - self.n_cutout, self.n_other = self.dataset.grids - - def register_attributes(self, graph: HeteroData, config: DotDict) -> None: - # this is a mask to cutout the LAM area - graph[self.name]["cutout"] = torch.tensor([True] * self.n_cutout + [False] * self.n_other, dtype=bool).reshape( - (-1, 1) - ) - return super().register_attributes(graph, config) - - -class NPZFileNodes(BaseNodeBuilder): - """Nodes from NPZ defined grids. - - Attributes - ---------- - resolution : str - The resolution of the grid. - grid_definition_path : str - Path to the folder containing the grid definition files. - grid_definition : dict[str, np.ndarray] - The grid definition. - - Methods - ------- - get_coordinates() - Get the lat-lon coordinates of the nodes. - register_nodes(graph, name) - Register the nodes in the graph. - register_attributes(graph, name, config) - Register the attributes in the nodes of the graph specified. - update_graph(graph, name, attr_config) - Update the graph with new nodes and attributes. - """ - - def __init__(self, resolution: str, grid_definition_path: str, name: str) -> None: - """Initialize the NPZFileNodes builder. - - The builder suppose the grids are stored in files with the name `grid-{resolution}.npz`. - - Parameters - ---------- - resolution : str - The resolution of the grid. - grid_definition_path : str - Path to the folder containing the grid definition files. - """ - self.resolution = resolution - self.grid_definition_path = grid_definition_path - self.grid_definition = np.load(Path(self.grid_definition_path) / f"grid-{self.resolution}.npz") - super().__init__(name) - - def get_coordinates(self) -> torch.Tensor: - """Get the coordinates of the nodes. - - Returns - ------- - torch.Tensor of shape (num_nodes, 2) - A 2D tensor with the coordinates, in radians. - """ - coords = self.reshape_coords(self.grid_definition["latitudes"], self.grid_definition["longitudes"]) - return coords - - -class LimitedAreaNPZFileNodes(NPZFileNodes): - """Nodes from NPZ defined grids using an area of interest.""" - - def __init__( - self, - resolution: str, - grid_definition_path: str, - name: str, - reference_node_name: str, - mask_attr_name: str, - margin_radius_km: float = 100.0, - ) -> None: - - self.aoi_mask_builder = KNNAreaMaskBuilder(reference_node_name, margin_radius_km, mask_attr_name) - - super().__init__(resolution, grid_definition_path, name) - - def register_nodes(self, graph: HeteroData) -> None: - self.aoi_mask_builder.fit(graph) - return super().register_nodes(graph) - - def get_coordinates(self) -> np.ndarray: - coords = super().get_coordinates() - - LOGGER.info( - "Limiting the processor mesh to a radius of %.2f km from the output mesh.", - self.aoi_mask_builder.margin_radius_km, - ) - aoi_mask = self.aoi_mask_builder.get_mask(coords) - - LOGGER.info("Dropping %d nodes from the processor mesh.", len(aoi_mask) - aoi_mask.sum()) - coords = coords[aoi_mask] - - return coords - - -class IcosahedralNodes(BaseNodeBuilder, ABC): - """Nodes based on iterative refinements of an icosahedron. - - Attributes - ---------- - resolution : list[int] | int - Refinement level of the mesh. - """ - - def __init__( - self, - resolution: Union[int, list[int]], - name: str, - ) -> None: - if isinstance(resolution, int): - self.resolutions = list(range(resolution + 1)) - else: - self.resolutions = resolution - - super().__init__(name) - - def get_coordinates(self) -> torch.Tensor: - """Get the coordinates of the nodes. - - Returns - ------- - torch.Tensor of shape (num_nodes, 2) - A 2D tensor with the coordinates, in radians. - """ - self.nx_graph, coords_rad, self.node_ordering = self.create_nodes() - return torch.tensor(coords_rad[self.node_ordering], dtype=torch.float32) - - @abstractmethod - def create_nodes(self) -> Tuple[nx.DiGraph, np.ndarray, list[int]]: ... - - def register_attributes(self, graph: HeteroData, config: DotDict) -> HeteroData: - graph[self.name]["_resolutions"] = self.resolutions - graph[self.name]["_nx_graph"] = self.nx_graph - graph[self.name]["_node_ordering"] = self.node_ordering - graph[self.name]["_aoi_mask_builder"] = self.aoi_mask_builder - return super().register_attributes(graph, config) - - -class LimitedAreaIcosahedralNodes(IcosahedralNodes): - """Nodes based on iterative refinements of an icosahedron using an area of interest. - - Attributes - ---------- - aoi_mask_builder : KNNAreaMaskBuilder - The area of interest mask builder. - """ - - def __init__( - self, - resolution: int | list[int], - name: str, - reference_node_name: str, - mask_attr_name: str, - margin_radius_km: float = 100.0, - ) -> None: - - super().__init__(resolution, name) - - self.aoi_mask_builder = KNNAreaMaskBuilder(reference_node_name, margin_radius_km, mask_attr_name) - - def register_nodes(self, graph: HeteroData) -> None: - self.aoi_mask_builder.fit(graph) - return super().register_nodes(graph) - - -class TriNodes(IcosahedralNodes): - """Nodes based on iterative refinements of an icosahedron. - - It depends on the trimesh Python library. - """ - - def create_nodes(self) -> Tuple[nx.Graph, np.ndarray, list[int]]: - return create_icosahedral_nodes(resolutions=self.resolutions) - - -class HexNodes(IcosahedralNodes): - """Nodes based on iterative refinements of an icosahedron. - - It depends on the h3 Python library. - """ - - def create_nodes(self) -> Tuple[nx.Graph, np.ndarray, list[int]]: - return create_hexagonal_nodes(resolutions=self.resolutions) - - -class LimitedAreaTriNodes(LimitedAreaIcosahedralNodes): - """Nodes based on iterative refinements of an icosahedron using an area of interest. - - It depends on the trimesh Python library. - - Parameters - ---------- - aoi_mask_builder: KNNAreaMaskBuilder - The area of interest mask builder. - """ - - def create_nodes(self) -> Tuple[nx.Graph, np.ndarray, list[int]]: - return create_icosahedral_nodes(resolutions=self.resolutions, aoi_mask_builder=self.aoi_mask_builder) - - -class LimitedAreaHexNodes(LimitedAreaIcosahedralNodes): - """Nodes based on iterative refinements of an icosahedron using an area of interest. - - It depends on the h3 Python library. - - Parameters - ---------- - aoi_mask_builder: KNNAreaMaskBuilder - The area of interest mask builder. - """ - - def create_nodes(self) -> Tuple[nx.Graph, np.ndarray, list[int]]: - return create_hexagonal_nodes(self.resolutions, aoi_mask_builder=self.aoi_mask_builder) - - -class HEALPixNodes(BaseNodeBuilder): - """Nodes from HEALPix grid. - - HEALPix is an acronym for Hierarchical Equal Area isoLatitude Pixelization of a sphere. - - Attributes - ---------- - resolution : int - The resolution of the grid. - - Methods - ------- - get_coordinates() - Get the lat-lon coordinates of the nodes. - register_nodes(graph, name) - Register the nodes in the graph. - register_attributes(graph, name, config) - Register the attributes in the nodes of the graph specified. - update_graph(graph, name, attr_config) - Update the graph with new nodes and attributes. - """ - - def __init__(self, resolution: int, name: str) -> None: - """Initialize the HEALPixNodes builder.""" - self.resolution = resolution - super().__init__(name) - - assert isinstance(resolution, int), "Resolution must be an integer." - assert resolution > 0, "Resolution must be positive." - - def get_coordinates(self) -> torch.Tensor: - """Get the coordinates of the nodes. - - Returns - ------- - torch.Tensor of shape (num_nodes, 2) - Coordinates of the nodes, in radians. - """ - import healpy as hp - - spatial_res_degrees = hp.nside2resol(2**self.resolution, arcmin=True) / 60 - LOGGER.info(f"Creating HEALPix nodes with resolution {spatial_res_degrees:.2} deg.") - - npix = hp.nside2npix(2**self.resolution) - hpxlon, hpxlat = hp.pix2ang(2**self.resolution, range(npix), nest=True, lonlat=True) - - return self.reshape_coords(hpxlat, hpxlon) - - -class LimitedAreaHEALPixNodes(HEALPixNodes): - """Nodes from HEALPix grid using an area of interest.""" - - def __init__( - self, - resolution: str, - name: str, - reference_node_name: str, - mask_attr_name: str, - margin_radius_km: float = 100.0, - ) -> None: - - self.aoi_mask_builder = KNNAreaMaskBuilder(reference_node_name, margin_radius_km, mask_attr_name) - - super().__init__(resolution, name) - - def register_nodes(self, graph: HeteroData) -> None: - self.aoi_mask_builder.fit(graph) - return super().register_nodes(graph) - - def get_coordinates(self) -> np.ndarray: - coords = super().get_coordinates() - - LOGGER.info( - 'Limiting the "%s" nodes to a radius of %.2f km from the nodes of interest.', - self.name, - self.aoi_mask_builder.margin_radius_km, - ) - aoi_mask = self.aoi_mask_builder.get_mask(coords) - - LOGGER.info('Masking out %d nodes from "%s".', len(aoi_mask) - aoi_mask.sum(), self.name) - coords = coords[aoi_mask] - - return coords diff --git a/src/anemoi/graphs/nodes/builders/base.py b/src/anemoi/graphs/nodes/builders/base.py new file mode 100644 index 0000000..a56552e --- /dev/null +++ b/src/anemoi/graphs/nodes/builders/base.py @@ -0,0 +1,105 @@ +from abc import ABC +from abc import abstractmethod +from typing import Optional + +import numpy as np +import torch +from anemoi.utils.config import DotDict +from hydra.utils import instantiate +from torch_geometric.data import HeteroData + + +class BaseNodeBuilder(ABC): + """Base class for node builders. + + The node coordinates are stored in the `x` attribute of the nodes and they are stored in radians. + + Attributes + ---------- + name : str + name of the nodes, key for the nodes in the HeteroData graph object. + aoi_mask_builder : KNNAreaMaskBuilder + The area of interest mask builder, if any. Defaults to None. + """ + + def __init__(self, name: str) -> None: + self.name = name + self.aoi_mask_builder = None + + def register_nodes(self, graph: HeteroData) -> None: + """Register nodes in the graph. + + Parameters + ---------- + graph : HeteroData + The graph to register the nodes. + """ + graph[self.name].x = self.get_coordinates() + graph[self.name].node_type = type(self).__name__ + return graph + + def register_attributes(self, graph: HeteroData, config: Optional[DotDict] = None) -> HeteroData: + """Register attributes in the nodes of the graph specified. + + Parameters + ---------- + graph : HeteroData + The graph to register the attributes. + config : DotDict + The configuration of the attributes. + + Returns + ------- + HeteroData + The graph with the registered attributes. + """ + for attr_name, attr_config in config.items(): + graph[self.name][attr_name] = instantiate(attr_config).compute(graph, self.name) + + return graph + + @abstractmethod + def get_coordinates(self) -> torch.Tensor: ... + + def reshape_coords(self, latitudes: np.ndarray, longitudes: np.ndarray) -> torch.Tensor: + """Reshape latitude and longitude coordinates. + + Parameters + ---------- + latitudes : np.ndarray of shape (num_nodes, ) + Latitude coordinates, in degrees. + longitudes : np.ndarray of shape (num_nodes, ) + Longitude coordinates, in degrees. + + Returns + ------- + torch.Tensor of shape (num_nodes, 2) + A 2D tensor with the coordinates, in radians. + """ + coords = np.stack([latitudes, longitudes], axis=-1).reshape((-1, 2)) + coords = np.deg2rad(coords) + return torch.tensor(coords, dtype=torch.float32) + + def update_graph(self, graph: HeteroData, attr_config: Optional[DotDict] = None) -> HeteroData: + """Update the graph with new nodes. + + Parameters + ---------- + graph : HeteroData + Input graph. + attr_config : DotDict + The configuration of the attributes. + + Returns + ------- + HeteroData + The graph with new nodes included. + """ + graph = self.register_nodes(graph) + + if attr_config is None: + return graph + + graph = self.register_attributes(graph, attr_config) + + return graph diff --git a/src/anemoi/graphs/nodes/builders/from_file.py b/src/anemoi/graphs/nodes/builders/from_file.py new file mode 100644 index 0000000..5695b0f --- /dev/null +++ b/src/anemoi/graphs/nodes/builders/from_file.py @@ -0,0 +1,159 @@ +import logging +from pathlib import Path + +import numpy as np +import torch +from anemoi.datasets import open_dataset +from anemoi.utils.config import DotDict +from torch_geometric.data import HeteroData + +from anemoi.graphs.nodes.builders.base import BaseNodeBuilder +from anemoi.graphs.nodes.masks import KNNAreaMaskBuilder + +LOGGER = logging.getLogger(__name__) + + +class ZarrDatasetNodes(BaseNodeBuilder): + """Nodes from Zarr dataset. + + Attributes + ---------- + dataset : zarr.core.Array + The dataset. + + Methods + ------- + get_coordinates() + Get the lat-lon coordinates of the nodes. + register_nodes(graph, name) + Register the nodes in the graph. + register_attributes(graph, name, config) + Register the attributes in the nodes of the graph specified. + update_graph(graph, name, attr_config) + Update the graph with new nodes and attributes. + """ + + def __init__(self, dataset: DotDict, name: str) -> None: + LOGGER.info("Reading the dataset from %s.", dataset) + self.dataset = open_dataset(dataset) + super().__init__(name) + + def get_coordinates(self) -> torch.Tensor: + """Get the coordinates of the nodes. + + Returns + ------- + torch.Tensor of shape (num_nodes, 2) + A 2D tensor with the coordinates, in radians. + """ + return self.reshape_coords(self.dataset.latitudes, self.dataset.longitudes) + + +class LimitedAreaZarrDatasetNodes(ZarrDatasetNodes): + """Nodes from Zarr dataset.""" + + def __init__( + self, name: str, lam_dataset: str, forcing_dataset: str, thinning: int = 1, adjust: str = "all" + ) -> None: + dataset_config = { + "cutout": [{"dataset": lam_dataset, "thinning": thinning}, {"dataset": forcing_dataset}], + "adjust": adjust, + } + super().__init__(dataset_config, name) + self.n_cutout, self.n_other = self.dataset.grids + + def register_attributes(self, graph: HeteroData, config: DotDict) -> None: + # this is a mask to cutout the LAM area + graph[self.name]["cutout"] = torch.tensor([True] * self.n_cutout + [False] * self.n_other, dtype=bool).reshape( + (-1, 1) + ) + return super().register_attributes(graph, config) + + +class NPZFileNodes(BaseNodeBuilder): + """Nodes from NPZ defined grids. + + Attributes + ---------- + resolution : str + The resolution of the grid. + grid_definition_path : str + Path to the folder containing the grid definition files. + grid_definition : dict[str, np.ndarray] + The grid definition. + + Methods + ------- + get_coordinates() + Get the lat-lon coordinates of the nodes. + register_nodes(graph, name) + Register the nodes in the graph. + register_attributes(graph, name, config) + Register the attributes in the nodes of the graph specified. + update_graph(graph, name, attr_config) + Update the graph with new nodes and attributes. + """ + + def __init__(self, resolution: str, grid_definition_path: str, name: str) -> None: + """Initialize the NPZFileNodes builder. + + The builder suppose the grids are stored in files with the name `grid-{resolution}.npz`. + + Parameters + ---------- + resolution : str + The resolution of the grid. + grid_definition_path : str + Path to the folder containing the grid definition files. + """ + self.resolution = resolution + self.grid_definition_path = grid_definition_path + self.grid_definition = np.load(Path(self.grid_definition_path) / f"grid-{self.resolution}.npz") + super().__init__(name) + + def get_coordinates(self) -> torch.Tensor: + """Get the coordinates of the nodes. + + Returns + ------- + torch.Tensor of shape (num_nodes, 2) + A 2D tensor with the coordinates, in radians. + """ + coords = self.reshape_coords(self.grid_definition["latitudes"], self.grid_definition["longitudes"]) + return coords + + +class LimitedAreaNPZFileNodes(NPZFileNodes): + """Nodes from NPZ defined grids using an area of interest.""" + + def __init__( + self, + resolution: str, + grid_definition_path: str, + name: str, + reference_node_name: str, + mask_attr_name: str, + margin_radius_km: float = 100.0, + ) -> None: + + self.aoi_mask_builder = KNNAreaMaskBuilder(reference_node_name, margin_radius_km, mask_attr_name) + + super().__init__(resolution, grid_definition_path, name) + + def register_nodes(self, graph: HeteroData) -> None: + self.aoi_mask_builder.fit(graph) + return super().register_nodes(graph) + + def get_coordinates(self) -> np.ndarray: + coords = super().get_coordinates() + + LOGGER.info( + "Limiting the processor mesh to a radius of %.2f km from the output mesh.", + self.aoi_mask_builder.margin_radius_km, + ) + aoi_mask = self.aoi_mask_builder.get_mask(coords) + + LOGGER.info("Dropping %d nodes from the processor mesh.", len(aoi_mask) - aoi_mask.sum()) + coords = coords[aoi_mask] + + return coords diff --git a/src/anemoi/graphs/nodes/builders/from_healpix.py b/src/anemoi/graphs/nodes/builders/from_healpix.py new file mode 100644 index 0000000..9b26669 --- /dev/null +++ b/src/anemoi/graphs/nodes/builders/from_healpix.py @@ -0,0 +1,95 @@ +import logging + +import numpy as np +import torch +from torch_geometric.data import HeteroData + +from anemoi.graphs.nodes.builders.base import BaseNodeBuilder +from anemoi.graphs.nodes.masks import KNNAreaMaskBuilder + +LOGGER = logging.getLogger(__name__) + + +class HEALPixNodes(BaseNodeBuilder): + """Nodes from HEALPix grid. + + HEALPix is an acronym for Hierarchical Equal Area isoLatitude Pixelization of a sphere. + + Attributes + ---------- + resolution : int + The resolution of the grid. + + Methods + ------- + get_coordinates() + Get the lat-lon coordinates of the nodes. + register_nodes(graph, name) + Register the nodes in the graph. + register_attributes(graph, name, config) + Register the attributes in the nodes of the graph specified. + update_graph(graph, name, attr_config) + Update the graph with new nodes and attributes. + """ + + def __init__(self, resolution: int, name: str) -> None: + """Initialize the HEALPixNodes builder.""" + self.resolution = resolution + super().__init__(name) + + assert isinstance(resolution, int), "Resolution must be an integer." + assert resolution > 0, "Resolution must be positive." + + def get_coordinates(self) -> torch.Tensor: + """Get the coordinates of the nodes. + + Returns + ------- + torch.Tensor of shape (num_nodes, 2) + Coordinates of the nodes, in radians. + """ + import healpy as hp + + spatial_res_degrees = hp.nside2resol(2**self.resolution, arcmin=True) / 60 + LOGGER.info(f"Creating HEALPix nodes with resolution {spatial_res_degrees:.2} deg.") + + npix = hp.nside2npix(2**self.resolution) + hpxlon, hpxlat = hp.pix2ang(2**self.resolution, range(npix), nest=True, lonlat=True) + + return self.reshape_coords(hpxlat, hpxlon) + + +class LimitedAreaHEALPixNodes(HEALPixNodes): + """Nodes from HEALPix grid using an area of interest.""" + + def __init__( + self, + resolution: str, + name: str, + reference_node_name: str, + mask_attr_name: str, + margin_radius_km: float = 100.0, + ) -> None: + + self.aoi_mask_builder = KNNAreaMaskBuilder(reference_node_name, margin_radius_km, mask_attr_name) + + super().__init__(resolution, name) + + def register_nodes(self, graph: HeteroData) -> None: + self.aoi_mask_builder.fit(graph) + return super().register_nodes(graph) + + def get_coordinates(self) -> np.ndarray: + coords = super().get_coordinates() + + LOGGER.info( + 'Limiting the "%s" nodes to a radius of %.2f km from the nodes of interest.', + self.name, + self.aoi_mask_builder.margin_radius_km, + ) + aoi_mask = self.aoi_mask_builder.get_mask(coords) + + LOGGER.info('Masking out %d nodes from "%s".', len(aoi_mask) - aoi_mask.sum(), self.name) + coords = coords[aoi_mask] + + return coords diff --git a/src/anemoi/graphs/nodes/builders/from_refined_icosahedron.py b/src/anemoi/graphs/nodes/builders/from_refined_icosahedron.py new file mode 100644 index 0000000..5ebe46a --- /dev/null +++ b/src/anemoi/graphs/nodes/builders/from_refined_icosahedron.py @@ -0,0 +1,138 @@ +import logging +from abc import ABC +from abc import abstractmethod +from typing import Tuple +from typing import Union + +import networkx as nx +import numpy as np +import torch +from anemoi.utils.config import DotDict +from torch_geometric.data import HeteroData + +from anemoi.graphs.generate.hexagonal import create_hexagonal_nodes +from anemoi.graphs.generate.icosahedral import create_icosahedral_nodes +from anemoi.graphs.nodes.builders.base import BaseNodeBuilder +from anemoi.graphs.nodes.masks import KNNAreaMaskBuilder + +LOGGER = logging.getLogger(__name__) + + +class IcosahedralNodes(BaseNodeBuilder, ABC): + """Nodes based on iterative refinements of an icosahedron. + + Attributes + ---------- + resolution : list[int] | int + Refinement level of the mesh. + """ + + def __init__( + self, + resolution: Union[int, list[int]], + name: str, + ) -> None: + if isinstance(resolution, int): + self.resolutions = list(range(resolution + 1)) + else: + self.resolutions = resolution + + super().__init__(name) + + def get_coordinates(self) -> torch.Tensor: + """Get the coordinates of the nodes. + + Returns + ------- + torch.Tensor of shape (num_nodes, 2) + A 2D tensor with the coordinates, in radians. + """ + self.nx_graph, coords_rad, self.node_ordering = self.create_nodes() + return torch.tensor(coords_rad[self.node_ordering], dtype=torch.float32) + + @abstractmethod + def create_nodes(self) -> Tuple[nx.DiGraph, np.ndarray, list[int]]: ... + + def register_attributes(self, graph: HeteroData, config: DotDict) -> HeteroData: + graph[self.name]["_resolutions"] = self.resolutions + graph[self.name]["_nx_graph"] = self.nx_graph + graph[self.name]["_node_ordering"] = self.node_ordering + graph[self.name]["_aoi_mask_builder"] = self.aoi_mask_builder + return super().register_attributes(graph, config) + + +class LimitedAreaIcosahedralNodes(IcosahedralNodes): + """Nodes based on iterative refinements of an icosahedron using an area of interest. + + Attributes + ---------- + aoi_mask_builder : KNNAreaMaskBuilder + The area of interest mask builder. + """ + + def __init__( + self, + resolution: int | list[int], + name: str, + reference_node_name: str, + mask_attr_name: str, + margin_radius_km: float = 100.0, + ) -> None: + + super().__init__(resolution, name) + + self.aoi_mask_builder = KNNAreaMaskBuilder(reference_node_name, margin_radius_km, mask_attr_name) + + def register_nodes(self, graph: HeteroData) -> None: + self.aoi_mask_builder.fit(graph) + return super().register_nodes(graph) + + +class TriNodes(IcosahedralNodes): + """Nodes based on iterative refinements of an icosahedron. + + It depends on the trimesh Python library. + """ + + def create_nodes(self) -> Tuple[nx.Graph, np.ndarray, list[int]]: + return create_icosahedral_nodes(resolutions=self.resolutions) + + +class HexNodes(IcosahedralNodes): + """Nodes based on iterative refinements of an icosahedron. + + It depends on the h3 Python library. + """ + + def create_nodes(self) -> Tuple[nx.Graph, np.ndarray, list[int]]: + return create_hexagonal_nodes(resolutions=self.resolutions) + + +class LimitedAreaTriNodes(LimitedAreaIcosahedralNodes): + """Nodes based on iterative refinements of an icosahedron using an area of interest. + + It depends on the trimesh Python library. + + Parameters + ---------- + aoi_mask_builder: KNNAreaMaskBuilder + The area of interest mask builder. + """ + + def create_nodes(self) -> Tuple[nx.Graph, np.ndarray, list[int]]: + return create_icosahedral_nodes(resolutions=self.resolutions, aoi_mask_builder=self.aoi_mask_builder) + + +class LimitedAreaHexNodes(LimitedAreaIcosahedralNodes): + """Nodes based on iterative refinements of an icosahedron using an area of interest. + + It depends on the h3 Python library. + + Parameters + ---------- + aoi_mask_builder: KNNAreaMaskBuilder + The area of interest mask builder. + """ + + def create_nodes(self) -> Tuple[nx.Graph, np.ndarray, list[int]]: + return create_hexagonal_nodes(self.resolutions, aoi_mask_builder=self.aoi_mask_builder) From 045ab0957feca939f26542babe678a504b24ad39 Mon Sep 17 00:00:00 2001 From: Mario Santa Cruz Date: Fri, 9 Aug 2024 11:26:59 +0000 Subject: [PATCH 126/156] fix: rename node builder class --- src/anemoi/graphs/nodes/__init__.py | 4 ++-- src/anemoi/graphs/nodes/builders/from_file.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/src/anemoi/graphs/nodes/__init__.py b/src/anemoi/graphs/nodes/__init__.py index 991c3c2..b5c919a 100644 --- a/src/anemoi/graphs/nodes/__init__.py +++ b/src/anemoi/graphs/nodes/__init__.py @@ -1,5 +1,5 @@ +from .builders.from_file import CutOutZarrDatasetNodes from .builders.from_file import LimitedAreaNPZFileNodes -from .builders.from_file import LimitedAreaZarrDatasetNodes from .builders.from_file import NPZFileNodes from .builders.from_file import ZarrDatasetNodes from .builders.from_healpix import HEALPixNodes @@ -16,7 +16,7 @@ "HexNodes", "HEALPixNodes", "LimitedAreaHEALPixNodes", - "LimitedAreaZarrDatasetNodes", + "CutOutZarrDatasetNodes", "LimitedAreaNPZFileNodes", "LimitedAreaTriNodes", "LimitedAreaHexNodes", diff --git a/src/anemoi/graphs/nodes/builders/from_file.py b/src/anemoi/graphs/nodes/builders/from_file.py index 5695b0f..0793e84 100644 --- a/src/anemoi/graphs/nodes/builders/from_file.py +++ b/src/anemoi/graphs/nodes/builders/from_file.py @@ -49,7 +49,7 @@ def get_coordinates(self) -> torch.Tensor: return self.reshape_coords(self.dataset.latitudes, self.dataset.longitudes) -class LimitedAreaZarrDatasetNodes(ZarrDatasetNodes): +class CutOutZarrDatasetNodes(ZarrDatasetNodes): """Nodes from Zarr dataset.""" def __init__( From 18dc2a3044dbbd28316838784a2c97477711bc56 Mon Sep 17 00:00:00 2001 From: Mario Santa Cruz Date: Fri, 9 Aug 2024 11:53:41 +0000 Subject: [PATCH 127/156] fix: test imports --- tests/nodes/test_healpix.py | 4 ++-- tests/nodes/test_hex_nodes.py | 2 +- tests/nodes/test_npz.py | 2 +- tests/nodes/test_tri_nodes.py | 2 +- tests/nodes/test_zarr.py | 20 ++++++++++---------- 5 files changed, 15 insertions(+), 15 deletions(-) diff --git a/tests/nodes/test_healpix.py b/tests/nodes/test_healpix.py index 3c6883c..3293c1c 100644 --- a/tests/nodes/test_healpix.py +++ b/tests/nodes/test_healpix.py @@ -4,8 +4,8 @@ from anemoi.graphs.nodes.attributes import AreaWeights from anemoi.graphs.nodes.attributes import UniformWeights -from anemoi.graphs.nodes.builder import BaseNodeBuilder -from anemoi.graphs.nodes.builder import HEALPixNodes +from anemoi.graphs.nodes.builders.base import BaseNodeBuilder +from anemoi.graphs.nodes.builders.from_healpix import HEALPixNodes @pytest.mark.parametrize("resolution", [2, 5, 7]) diff --git a/tests/nodes/test_hex_nodes.py b/tests/nodes/test_hex_nodes.py index 03e00da..54c2b6b 100644 --- a/tests/nodes/test_hex_nodes.py +++ b/tests/nodes/test_hex_nodes.py @@ -3,7 +3,7 @@ from torch_geometric.data import HeteroData from anemoi.graphs.nodes import HexNodes -from anemoi.graphs.nodes.builder import BaseNodeBuilder +from anemoi.graphs.nodes.builders.base import BaseNodeBuilder @pytest.mark.parametrize("resolution", [0, 2]) diff --git a/tests/nodes/test_npz.py b/tests/nodes/test_npz.py index 95d09c0..21b767a 100644 --- a/tests/nodes/test_npz.py +++ b/tests/nodes/test_npz.py @@ -4,7 +4,7 @@ from anemoi.graphs.nodes.attributes import AreaWeights from anemoi.graphs.nodes.attributes import UniformWeights -from anemoi.graphs.nodes.builder import NPZFileNodes +from anemoi.graphs.nodes.builders.from_file import NPZFileNodes @pytest.mark.parametrize("resolution", ["o16", "o48", "5km5"]) diff --git a/tests/nodes/test_tri_nodes.py b/tests/nodes/test_tri_nodes.py index af1af69..4f522ce 100644 --- a/tests/nodes/test_tri_nodes.py +++ b/tests/nodes/test_tri_nodes.py @@ -3,7 +3,7 @@ from torch_geometric.data import HeteroData from anemoi.graphs.nodes import TriNodes -from anemoi.graphs.nodes.builder import BaseNodeBuilder +from anemoi.graphs.nodes.builders.base import BaseNodeBuilder @pytest.mark.parametrize("resolution", [0, 2]) diff --git a/tests/nodes/test_zarr.py b/tests/nodes/test_zarr.py index e7c98cc..90610ac 100644 --- a/tests/nodes/test_zarr.py +++ b/tests/nodes/test_zarr.py @@ -3,30 +3,30 @@ import zarr from torch_geometric.data import HeteroData -from anemoi.graphs.nodes import builder from anemoi.graphs.nodes.attributes import AreaWeights from anemoi.graphs.nodes.attributes import UniformWeights +from anemoi.graphs.nodes.builders import from_file def test_init(mocker, mock_zarr_dataset): """Test ZarrDatasetNodes initialization.""" - mocker.patch.object(builder, "open_dataset", return_value=mock_zarr_dataset) - node_builder = builder.ZarrDatasetNodes("dataset.zarr", name="test_nodes") + mocker.patch.object(from_file, "open_dataset", return_value=mock_zarr_dataset) + node_builder = from_file.ZarrDatasetNodes("dataset.zarr", name="test_nodes") - assert isinstance(node_builder, builder.BaseNodeBuilder) - assert isinstance(node_builder, builder.ZarrDatasetNodes) + assert isinstance(node_builder, from_file.BaseNodeBuilder) + assert isinstance(node_builder, from_file.ZarrDatasetNodes) def test_fail_init(): """Test ZarrDatasetNodes initialization with invalid resolution.""" with pytest.raises(zarr.errors.PathNotFoundError): - builder.ZarrDatasetNodes("invalid_path.zarr", name="test_nodes") + from_file.ZarrDatasetNodes("invalid_path.zarr", name="test_nodes") def test_register_nodes(mocker, mock_zarr_dataset): """Test ZarrDatasetNodes register correctly the nodes.""" - mocker.patch.object(builder, "open_dataset", return_value=mock_zarr_dataset) - node_builder = builder.ZarrDatasetNodes("dataset.zarr", name="test_nodes") + mocker.patch.object(from_file, "open_dataset", return_value=mock_zarr_dataset) + node_builder = from_file.ZarrDatasetNodes("dataset.zarr", name="test_nodes") graph = HeteroData() graph = node_builder.register_nodes(graph) @@ -40,8 +40,8 @@ def test_register_nodes(mocker, mock_zarr_dataset): @pytest.mark.parametrize("attr_class", [UniformWeights, AreaWeights]) def test_register_attributes(mocker, graph_with_nodes: HeteroData, attr_class): """Test ZarrDatasetNodes register correctly the weights.""" - mocker.patch.object(builder, "open_dataset", return_value=None) - node_builder = builder.ZarrDatasetNodes("dataset.zarr", name="test_nodes") + mocker.patch.object(from_file, "open_dataset", return_value=None) + node_builder = from_file.ZarrDatasetNodes("dataset.zarr", name="test_nodes") config = {"test_attr": {"_target_": f"anemoi.graphs.nodes.attributes.{attr_class.__name__}"}} graph = node_builder.register_attributes(graph_with_nodes, config) From fee21606a8e21e86b8051057a61d58b44c79f84b Mon Sep 17 00:00:00 2001 From: Mario Santa Cruz Date: Fri, 9 Aug 2024 11:59:11 +0000 Subject: [PATCH 128/156] Updated CHANGELOG.md --- CHANGELOG.md | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 0e4747c..46f600c 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -11,6 +11,10 @@ Keep it human-readable, your future self will thank you! ## [Unreleased] ### Added +- New node builder class, CutOutZarrDatasetNodes, to create nodes from 2 datasets. +- New class, KNNAreaMaskBuilder, to specify Area of Interest (AOI) based on a set of nodes. +- New node builder classes, LimitedAreaXXXXXNodes, to create nodes within an Area of Interest (AOI). +- Expanded MultiScaleEdges to support multi-scale connections in limited area graphs. - HEALPixNodes - nodebuilder based on Hierarchical Equal Area isoLatitude Pixelation of a sphere - added downstream-ci pipeline From 6a19a625d04351b0391599b8dd32e30d744373be Mon Sep 17 00:00:00 2001 From: Mario Santa Cruz Date: Fri, 9 Aug 2024 14:25:30 +0000 Subject: [PATCH 129/156] fix: imports --- src/anemoi/graphs/edges/builder.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/anemoi/graphs/edges/builder.py b/src/anemoi/graphs/edges/builder.py index a36fd51..60bc8c8 100644 --- a/src/anemoi/graphs/edges/builder.py +++ b/src/anemoi/graphs/edges/builder.py @@ -15,10 +15,10 @@ from anemoi.graphs import EARTH_RADIUS from anemoi.graphs.generate import hexagonal from anemoi.graphs.generate import icosahedral -from anemoi.graphs.nodes.builder import HexNodes -from anemoi.graphs.nodes.builder import LimitedAreaHexNodes -from anemoi.graphs.nodes.builder import LimitedAreaTriNodes -from anemoi.graphs.nodes.builder import TriNodes +from anemoi.graphs.nodes.builders.from_refined_icosahedron import HexNodes +from anemoi.graphs.nodes.builders.from_refined_icosahedron import LimitedAreaHexNodes +from anemoi.graphs.nodes.builders.from_refined_icosahedron import LimitedAreaTriNodes +from anemoi.graphs.nodes.builders.from_refined_icosahedron import TriNodes from anemoi.graphs.utils import get_grid_reference_distance LOGGER = logging.getLogger(__name__) From b4df29640ebfe3a4496f7d997cb2611d9db9f07c Mon Sep 17 00:00:00 2001 From: Mario Santa Cruz Date: Fri, 9 Aug 2024 14:30:48 +0000 Subject: [PATCH 130/156] fix: move masks.py to generate/ --- src/anemoi/graphs/generate/hexagonal.py | 2 +- src/anemoi/graphs/generate/icosahedral.py | 2 +- src/anemoi/graphs/{nodes => generate}/masks.py | 0 src/anemoi/graphs/nodes/builders/__init__.py | 0 src/anemoi/graphs/nodes/builders/from_file.py | 2 +- src/anemoi/graphs/nodes/builders/from_healpix.py | 2 +- src/anemoi/graphs/nodes/builders/from_refined_icosahedron.py | 2 +- 7 files changed, 5 insertions(+), 5 deletions(-) rename src/anemoi/graphs/{nodes => generate}/masks.py (100%) create mode 100644 src/anemoi/graphs/nodes/builders/__init__.py diff --git a/src/anemoi/graphs/generate/hexagonal.py b/src/anemoi/graphs/generate/hexagonal.py index 2dd7105..051efcb 100644 --- a/src/anemoi/graphs/generate/hexagonal.py +++ b/src/anemoi/graphs/generate/hexagonal.py @@ -4,8 +4,8 @@ import networkx as nx import numpy as np +from anemoi.graphs.generate.masks import KNNAreaMaskBuilder from anemoi.graphs.generate.utils import get_coordinates_ordering -from anemoi.graphs.nodes.masks import KNNAreaMaskBuilder def create_hexagonal_nodes( diff --git a/src/anemoi/graphs/generate/icosahedral.py b/src/anemoi/graphs/generate/icosahedral.py index 49251f2..c8ca750 100644 --- a/src/anemoi/graphs/generate/icosahedral.py +++ b/src/anemoi/graphs/generate/icosahedral.py @@ -7,9 +7,9 @@ import trimesh from sklearn.neighbors import BallTree +from anemoi.graphs.generate.masks import KNNAreaMaskBuilder from anemoi.graphs.generate.transforms import cartesian_to_latlon_rad from anemoi.graphs.generate.utils import get_coordinates_ordering -from anemoi.graphs.nodes.masks import KNNAreaMaskBuilder logger = logging.getLogger(__name__) diff --git a/src/anemoi/graphs/nodes/masks.py b/src/anemoi/graphs/generate/masks.py similarity index 100% rename from src/anemoi/graphs/nodes/masks.py rename to src/anemoi/graphs/generate/masks.py diff --git a/src/anemoi/graphs/nodes/builders/__init__.py b/src/anemoi/graphs/nodes/builders/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/anemoi/graphs/nodes/builders/from_file.py b/src/anemoi/graphs/nodes/builders/from_file.py index 0793e84..1a2eb54 100644 --- a/src/anemoi/graphs/nodes/builders/from_file.py +++ b/src/anemoi/graphs/nodes/builders/from_file.py @@ -7,8 +7,8 @@ from anemoi.utils.config import DotDict from torch_geometric.data import HeteroData +from anemoi.graphs.generate.masks import KNNAreaMaskBuilder from anemoi.graphs.nodes.builders.base import BaseNodeBuilder -from anemoi.graphs.nodes.masks import KNNAreaMaskBuilder LOGGER = logging.getLogger(__name__) diff --git a/src/anemoi/graphs/nodes/builders/from_healpix.py b/src/anemoi/graphs/nodes/builders/from_healpix.py index 9b26669..a1b8c46 100644 --- a/src/anemoi/graphs/nodes/builders/from_healpix.py +++ b/src/anemoi/graphs/nodes/builders/from_healpix.py @@ -4,8 +4,8 @@ import torch from torch_geometric.data import HeteroData +from anemoi.graphs.generate.masks import KNNAreaMaskBuilder from anemoi.graphs.nodes.builders.base import BaseNodeBuilder -from anemoi.graphs.nodes.masks import KNNAreaMaskBuilder LOGGER = logging.getLogger(__name__) diff --git a/src/anemoi/graphs/nodes/builders/from_refined_icosahedron.py b/src/anemoi/graphs/nodes/builders/from_refined_icosahedron.py index 5ebe46a..6a4aaae 100644 --- a/src/anemoi/graphs/nodes/builders/from_refined_icosahedron.py +++ b/src/anemoi/graphs/nodes/builders/from_refined_icosahedron.py @@ -12,8 +12,8 @@ from anemoi.graphs.generate.hexagonal import create_hexagonal_nodes from anemoi.graphs.generate.icosahedral import create_icosahedral_nodes +from anemoi.graphs.generate.masks import KNNAreaMaskBuilder from anemoi.graphs.nodes.builders.base import BaseNodeBuilder -from anemoi.graphs.nodes.masks import KNNAreaMaskBuilder LOGGER = logging.getLogger(__name__) From 5d54f2f5391f1015893e279b0cf41b3614068343 Mon Sep 17 00:00:00 2001 From: Mario Santa Cruz Date: Fri, 9 Aug 2024 14:39:33 +0000 Subject: [PATCH 131/156] refactor: update resolutions argument to resolution --- src/anemoi/graphs/generate/hexagonal.py | 8 ++++---- src/anemoi/graphs/generate/icosahedral.py | 8 ++++---- .../graphs/nodes/builders/from_refined_icosahedron.py | 8 ++++---- 3 files changed, 12 insertions(+), 12 deletions(-) diff --git a/src/anemoi/graphs/generate/hexagonal.py b/src/anemoi/graphs/generate/hexagonal.py index 051efcb..65ad3c3 100644 --- a/src/anemoi/graphs/generate/hexagonal.py +++ b/src/anemoi/graphs/generate/hexagonal.py @@ -9,7 +9,7 @@ def create_hexagonal_nodes( - resolutions: list[int], + resolution: int, aoi_mask_builder: Optional[KNNAreaMaskBuilder] = None, ) -> tuple[nx.Graph, np.ndarray, list[int]]: """Creates a global mesh from a refined icosahedron. @@ -19,8 +19,8 @@ def create_hexagonal_nodes( Parameters ---------- - resolutions : list[int] - Levels of mesh resolution to consider. + resolution : int + Level of mesh resolution to consider. aoi_mask_builder : KNNAreaMaskBuilder, optional KNNAreaMaskBuilder with the cloud of points to limit the mesh area, by default None. @@ -33,7 +33,7 @@ def create_hexagonal_nodes( node_ordering : list[int] Order of the node coordinates to be sorted by latitude and longitude. """ - nodes = get_nodes_at_resolution(max(resolutions)) + nodes = get_nodes_at_resolution(resolution) coords_rad = np.deg2rad(np.array([h3.h3_to_geo(node) for node in nodes])) diff --git a/src/anemoi/graphs/generate/icosahedral.py b/src/anemoi/graphs/generate/icosahedral.py index c8ca750..3cddaa2 100644 --- a/src/anemoi/graphs/generate/icosahedral.py +++ b/src/anemoi/graphs/generate/icosahedral.py @@ -15,7 +15,7 @@ def create_icosahedral_nodes( - resolutions: list[int], aoi_mask_builder: Optional[KNNAreaMaskBuilder] = None + resolution: int, aoi_mask_builder: Optional[KNNAreaMaskBuilder] = None ) -> tuple[nx.DiGraph, np.ndarray, list[int]]: """Creates a global mesh following AIFS strategy. @@ -23,8 +23,8 @@ def create_icosahedral_nodes( Parameters ---------- - resolutions : list[int] - Levels of mesh resolution to consider. + resolution : int + Level of mesh resolution to consider. aoi_mask_builder : KNNAreaMaskBuilder KNNAreaMaskBuilder with the cloud of points to limit the mesh area, by default None. @@ -37,7 +37,7 @@ def create_icosahedral_nodes( node_ordering : list[int] Order of the node coordinates to be sorted by latitude and longitude. """ - sphere = trimesh.creation.icosphere(subdivisions=resolutions[-1], radius=1.0) + sphere = trimesh.creation.icosphere(subdivisions=resolution, radius=1.0) coords_rad = cartesian_to_latlon_rad(sphere.vertices) diff --git a/src/anemoi/graphs/nodes/builders/from_refined_icosahedron.py b/src/anemoi/graphs/nodes/builders/from_refined_icosahedron.py index 6a4aaae..11cb9f2 100644 --- a/src/anemoi/graphs/nodes/builders/from_refined_icosahedron.py +++ b/src/anemoi/graphs/nodes/builders/from_refined_icosahedron.py @@ -95,7 +95,7 @@ class TriNodes(IcosahedralNodes): """ def create_nodes(self) -> Tuple[nx.Graph, np.ndarray, list[int]]: - return create_icosahedral_nodes(resolutions=self.resolutions) + return create_icosahedral_nodes(resolution=max(self.resolutions)) class HexNodes(IcosahedralNodes): @@ -105,7 +105,7 @@ class HexNodes(IcosahedralNodes): """ def create_nodes(self) -> Tuple[nx.Graph, np.ndarray, list[int]]: - return create_hexagonal_nodes(resolutions=self.resolutions) + return create_hexagonal_nodes(resolution=max(self.resolutions)) class LimitedAreaTriNodes(LimitedAreaIcosahedralNodes): @@ -120,7 +120,7 @@ class LimitedAreaTriNodes(LimitedAreaIcosahedralNodes): """ def create_nodes(self) -> Tuple[nx.Graph, np.ndarray, list[int]]: - return create_icosahedral_nodes(resolutions=self.resolutions, aoi_mask_builder=self.aoi_mask_builder) + return create_icosahedral_nodes(resolution=max(self.resolutions), aoi_mask_builder=self.aoi_mask_builder) class LimitedAreaHexNodes(LimitedAreaIcosahedralNodes): @@ -135,4 +135,4 @@ class LimitedAreaHexNodes(LimitedAreaIcosahedralNodes): """ def create_nodes(self) -> Tuple[nx.Graph, np.ndarray, list[int]]: - return create_hexagonal_nodes(self.resolutions, aoi_mask_builder=self.aoi_mask_builder) + return create_hexagonal_nodes(resolution=max(self.resolutions), aoi_mask_builder=self.aoi_mask_builder) From b9ae468008162edf52cad9ffcbf6bb996d1c542d Mon Sep 17 00:00:00 2001 From: Mario Santa Cruz Date: Mon, 19 Aug 2024 14:10:47 +0000 Subject: [PATCH 132/156] Merge 'develop' branch into 7-local-area-modelling-graphs --- .../workflows/changelog-release-update.yml | 35 ++++++++++++++++ .github/workflows/ci.yml | 17 ++++++-- .github/workflows/python-publish.yml | 42 ++++--------------- .github/workflows/python-pull-request.yml | 23 ++++++++++ CHANGELOG.md | 15 ++++--- src/anemoi/graphs/create.py | 8 ++-- src/anemoi/graphs/edges/attributes.py | 9 ++-- src/anemoi/graphs/edges/builder.py | 7 ++-- src/anemoi/graphs/edges/directional.py | 4 +- src/anemoi/graphs/generate/hexagonal.py | 8 ++-- src/anemoi/graphs/generate/icosahedral.py | 11 ++--- src/anemoi/graphs/nodes/attributes.py | 9 ++-- src/anemoi/graphs/nodes/builders/base.py | 5 +-- .../builders/from_refined_icosahedron.py | 16 +++---- src/anemoi/graphs/utils.py | 6 +-- 15 files changed, 129 insertions(+), 86 deletions(-) create mode 100644 .github/workflows/changelog-release-update.yml create mode 100644 .github/workflows/python-pull-request.yml diff --git a/.github/workflows/changelog-release-update.yml b/.github/workflows/changelog-release-update.yml new file mode 100644 index 0000000..17d9525 --- /dev/null +++ b/.github/workflows/changelog-release-update.yml @@ -0,0 +1,35 @@ +# .github/workflows/update-changelog.yaml +name: "Update Changelog" + +on: + release: + types: [released] + workflow_dispatch: ~ + +permissions: + pull-requests: write + contents: write + +jobs: + update: + runs-on: ubuntu-latest + + steps: + - name: Checkout code + uses: actions/checkout@v4 + with: + ref: ${{ github.event.release.target_commitish }} + + - name: Update Changelog + uses: stefanzweifel/changelog-updater-action@v1 + with: + latest-version: ${{ github.event.release.tag_name }} + heading-text: ${{ github.event.release.name }} + + - name: Create Pull Request + uses: peter-evans/create-pull-request@v6 + with: + branch: docs/changelog-update-${{ github.event.release.tag_name }} + title: '[Changelog] Update to ${{ github.event.release.tag_name }}' + add-paths: | + CHANGELOG.md diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 4ffa63a..f9e9f91 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -8,12 +8,17 @@ on: - 'develop' tags-ignore: - '**' - paths: - - "src/**" - - "tests/**" + paths-ignore: + - "docs/**" + - "CHANGELOG.md" + - "README.md" # Trigger the workflow on pull request - pull_request: ~ + pull_request: + paths-ignore: + - "docs/**" + - "CHANGELOG.md" + - "README.md" # Trigger the workflow manually workflow_dispatch: ~ @@ -21,6 +26,10 @@ on: # Trigger after public PR approved for CI pull_request_target: types: [labeled] + paths-ignore: + - "docs/**" + - "CHANGELOG.md" + - "README.md" jobs: # Run CI including downstream packages on self-hosted runners diff --git a/.github/workflows/python-publish.yml b/.github/workflows/python-publish.yml index 816e1d8..2cb554a 100644 --- a/.github/workflows/python-publish.yml +++ b/.github/workflows/python-publish.yml @@ -4,50 +4,22 @@ name: Upload Python Package on: - - push: {} - release: types: [created] jobs: quality: - name: Code QA - runs-on: ubuntu-latest - steps: - - run: sudo apt-get install -y pandoc # Needed by sphinx for notebooks - - uses: actions/checkout@v4 - - uses: actions/setup-python@v5 - with: - python-version: 3.x - - uses: pre-commit/action@v3.0.1 - env: - SKIP: no-commit-to-branch + uses: ecmwf-actions/reusable-workflows/.github/workflows/qa-precommit-run.yml@v2 + with: + skip-hooks: "no-commit-to-branch" checks: strategy: - fail-fast: false matrix: - platform: ["ubuntu-latest", "macos-latest"] - python-version: ["3.10"] - - name: Python ${{ matrix.python-version }} on ${{ matrix.platform }} - runs-on: ${{ matrix.platform }} - - steps: - - uses: actions/checkout@v4 - - - uses: actions/setup-python@v5 - with: - python-version: ${{ matrix.python-version }} - - - name: Install - run: | - pip install -e .[all,tests] - pip freeze - - - name: Tests - run: pytest + python-version: ["3.9", "3.10"] + uses: ecmwf-actions/reusable-workflows/.github/workflows/qa-pytest-pyproject.yml@v2 + with: + python-version: ${{ matrix.python-version }} deploy: needs: [checks, quality] diff --git a/.github/workflows/python-pull-request.yml b/.github/workflows/python-pull-request.yml new file mode 100644 index 0000000..0ebecb1 --- /dev/null +++ b/.github/workflows/python-pull-request.yml @@ -0,0 +1,23 @@ +# This workflow will upload a Python Package using Twine when a release is created +# For more information see: https://help.github.com/en/actions/language-and-framework-guides/using-python-with-github-actions#publishing-to-package-registries + +name: Code Quality checks for PRs + +on: + push: + pull_request_target: + types: [opened, synchronize, reopened] + +jobs: + quality: + uses: ecmwf-actions/reusable-workflows/.github/workflows/qa-precommit-run.yml@v2 + with: + skip-hooks: "no-commit-to-branch" + + checks: + strategy: + matrix: + python-version: ["3.9", "3.10"] + uses: ecmwf-actions/reusable-workflows/.github/workflows/qa-pytest-pyproject.yml@v2 + with: + python-version: ${{ matrix.python-version }} diff --git a/CHANGELOG.md b/CHANGELOG.md index 0b2c19b..7aaa207 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -17,13 +17,18 @@ Keep it human-readable, your future self will thank you! - Expanded MultiScaleEdges to support multi-scale connections in limited area graphs. - HEALPixNodes - nodebuilder based on Hierarchical Equal Area isoLatitude Pixelation of a sphere - added downstream-ci pipeline and cd-pypi reusable workflow +- Changelog release updater ### Changed -- Fix bug in graph cleaning method -- Fix `anemoi-graphs create`. Config argument is cast to a Path. -- Fix GraphCreator().clean() to not iterate over a dictionary that may change size during iterations. -- Fix missing binary dependency -- **Fix**: Updated `get_raw_values` method in `AreaWeights` to ensure compatibility with `scipy.spatial.SphericalVoronoi` by converting `latitudes` and `longitudes` to NumPy arrays before passing them to the `latlon_rad_to_cartesian` function. This resolves an issue where the function would fail if passed Torch Tensors directly. +- fix: added support for Python3.9. +- fix: bug in graph cleaning method +- fix: `anemoi-graphs create` CLI argument is casted to a Path. +- ci: fix missing binary dependency in ci-config.yaml +- fix: Updated `get_raw_values` method in `AreaWeights` to ensure compatibility with `scipy.spatial.SphericalVoronoi` by converting `latitudes` and `longitudes` to NumPy arrays before passing them to the `latlon_rad_to_cartesian` function. This resolves an issue where the function would fail if passed Torch Tensors directly. +- ci: Reusable workflows for push, PR, and releases +- ci: ignore docs for downstream ci +- ci: changed Changelog action to create PR +- ci: fixes and permissions on changelog updater ### Removed diff --git a/src/anemoi/graphs/create.py b/src/anemoi/graphs/create.py index d52b37f..9f10c81 100644 --- a/src/anemoi/graphs/create.py +++ b/src/anemoi/graphs/create.py @@ -1,8 +1,8 @@ +from __future__ import annotations + import logging from itertools import chain from pathlib import Path -from typing import Optional -from typing import Union import torch from anemoi.utils.config import DotDict @@ -17,7 +17,7 @@ class GraphCreator: def __init__( self, - config: Union[Path, DotDict], + config: str | Path | DotDict, ): if isinstance(config, Path) or isinstance(config, str): self.config = DotDict.from_file(config) @@ -91,7 +91,7 @@ def save(self, graph: HeteroData, save_path: Path, overwrite: bool = False) -> N else: LOGGER.info("Graph already exists. Use overwrite=True to overwrite.") - def create(self, save_path: Optional[Path] = None, overwrite: bool = False) -> HeteroData: + def create(self, save_path: Path | None = None, overwrite: bool = False) -> HeteroData: """Create the graph and save it to the output path. Parameters diff --git a/src/anemoi/graphs/edges/attributes.py b/src/anemoi/graphs/edges/attributes.py index 9a8d6d8..c65a402 100644 --- a/src/anemoi/graphs/edges/attributes.py +++ b/src/anemoi/graphs/edges/attributes.py @@ -1,7 +1,8 @@ +from __future__ import annotations + import logging from abc import ABC from abc import abstractmethod -from typing import Optional import numpy as np import torch @@ -17,7 +18,7 @@ class BaseEdgeAttribute(ABC, NormalizerMixin): """Base class for edge attributes.""" - def __init__(self, norm: Optional[str] = None) -> None: + def __init__(self, norm: str | None = None) -> None: self.norm = norm @abstractmethod @@ -69,7 +70,7 @@ class EdgeDirection(BaseEdgeAttribute): Compute directional attributes. """ - def __init__(self, norm: Optional[str] = None, luse_rotated_features: bool = True) -> None: + def __init__(self, norm: str | None = None, luse_rotated_features: bool = True) -> None: super().__init__(norm) self.luse_rotated_features = luse_rotated_features @@ -115,7 +116,7 @@ class EdgeLength(BaseEdgeAttribute): Compute edge lengths attributes. """ - def __init__(self, norm: Optional[str] = None, invert: bool = False) -> None: + def __init__(self, norm: str | None = None, invert: bool = False) -> None: super().__init__(norm) self.invert = invert diff --git a/src/anemoi/graphs/edges/builder.py b/src/anemoi/graphs/edges/builder.py index 60bc8c8..1b22b2e 100644 --- a/src/anemoi/graphs/edges/builder.py +++ b/src/anemoi/graphs/edges/builder.py @@ -1,7 +1,8 @@ +from __future__ import annotations + import logging from abc import ABC from abc import abstractmethod -from typing import Optional import networkx as nx import numpy as np @@ -101,7 +102,7 @@ def register_attributes(self, graph: HeteroData, config: DotDict) -> HeteroData: graph[self.name][attr_name] = instantiate(attr_config).compute(graph, self.name) return graph - def update_graph(self, graph: HeteroData, attrs_config: Optional[DotDict] = None) -> HeteroData: + def update_graph(self, graph: HeteroData, attrs_config: DotDict | None = None) -> HeteroData: """Update the graph with the edges. Parameters @@ -216,7 +217,7 @@ def __init__(self, source_name: str, target_name: str, cutoff_factor: float): assert cutoff_factor > 0, "Cutoff factor must be positive" self.cutoff_factor = cutoff_factor - def get_cutoff_radius(self, graph: HeteroData, mask_attr: Optional[torch.Tensor] = None): + def get_cutoff_radius(self, graph: HeteroData, mask_attr: torch.Tensor | None = None): """Compute the cut-off radius. The cut-off radius is computed as the product of the target nodes reference distance and the cut-off factor. diff --git a/src/anemoi/graphs/edges/directional.py b/src/anemoi/graphs/edges/directional.py index 9c7cdea..322f37e 100644 --- a/src/anemoi/graphs/edges/directional.py +++ b/src/anemoi/graphs/edges/directional.py @@ -1,4 +1,4 @@ -from typing import Optional +from __future__ import annotations import numpy as np from scipy.spatial.transform import Rotation @@ -28,7 +28,7 @@ def get_rotation_from_unit_vecs(points: np.ndarray, reference: np.ndarray) -> Ro return Rotation.from_rotvec(np.transpose(v_unit * theta)) -def compute_directions(loc1: np.ndarray, loc2: np.ndarray, pole_vec: Optional[np.ndarray] = None) -> np.ndarray: +def compute_directions(loc1: np.ndarray, loc2: np.ndarray, pole_vec: np.ndarray | None = None) -> np.ndarray: """Compute the direction of the edge joining the nodes considered. Parameters diff --git a/src/anemoi/graphs/generate/hexagonal.py b/src/anemoi/graphs/generate/hexagonal.py index 65ad3c3..c8d46b0 100644 --- a/src/anemoi/graphs/generate/hexagonal.py +++ b/src/anemoi/graphs/generate/hexagonal.py @@ -1,4 +1,4 @@ -from typing import Optional +from __future__ import annotations import h3 import networkx as nx @@ -10,7 +10,7 @@ def create_hexagonal_nodes( resolution: int, - aoi_mask_builder: Optional[KNNAreaMaskBuilder] = None, + aoi_mask_builder: KNNAreaMaskBuilder | None = None, ) -> tuple[nx.Graph, np.ndarray, list[int]]: """Creates a global mesh from a refined icosahedron. @@ -83,8 +83,6 @@ def get_nodes_at_resolution( ---------- resolution : int The H3 refinement level. It can be an integer from 0 to 15. - aoi_mask_builder : KNNAreaMaskBuilder, optional - KNNAreaMaskBuilder computes nask to limit the mesh area, by default None. Returns ------- @@ -154,7 +152,7 @@ def add_neighbour_edges( def add_edges_to_children( graph: nx.Graph, refinement_levels: tuple[int], - depth_children: Optional[int] = None, + depth_children: int | None = None, ) -> nx.Graph: """Adds edges to the children of the nodes at the specified resolution levels. diff --git a/src/anemoi/graphs/generate/icosahedral.py b/src/anemoi/graphs/generate/icosahedral.py index 3cddaa2..4415c09 100644 --- a/src/anemoi/graphs/generate/icosahedral.py +++ b/src/anemoi/graphs/generate/icosahedral.py @@ -1,6 +1,7 @@ +from __future__ import annotations + import logging from collections.abc import Iterable -from typing import Optional import networkx as nx import numpy as np @@ -15,7 +16,7 @@ def create_icosahedral_nodes( - resolution: int, aoi_mask_builder: Optional[KNNAreaMaskBuilder] = None + resolution: int, aoi_mask_builder: KNNAreaMaskBuilder | None = None ) -> tuple[nx.DiGraph, np.ndarray, list[int]]: """Creates a global mesh following AIFS strategy. @@ -69,7 +70,7 @@ def add_edges_to_nx_graph( graph: nx.DiGraph, resolutions: list[int], x_hops: int = 1, - aoi_mask_builder: Optional[KNNAreaMaskBuilder] = None, + aoi_mask_builder: KNNAreaMaskBuilder | None = None, ) -> nx.DiGraph: """Adds the edges to the graph. @@ -128,7 +129,7 @@ def add_edges_to_nx_graph( def get_neighbours_within_hops( - tri_mesh: trimesh.Trimesh, x_hops: int, valid_nodes: Optional[list[int]] = None + tri_mesh: trimesh.Trimesh, x_hops: int, valid_nodes: list[int] | None = None ) -> dict[int, set[int]]: """Get the neigbour connections in the graph. @@ -168,7 +169,7 @@ def add_neigbours_edges( node_idx: int, neighbour_indices: Iterable[int], self_loops: bool = False, - vertex_mapping_index: Optional[np.ndarray] = None, + vertex_mapping_index: np.ndarray | None = None, ) -> None: """Adds the edges of one node to its neighbours. diff --git a/src/anemoi/graphs/nodes/attributes.py b/src/anemoi/graphs/nodes/attributes.py index 040f134..e0ed1d8 100644 --- a/src/anemoi/graphs/nodes/attributes.py +++ b/src/anemoi/graphs/nodes/attributes.py @@ -1,7 +1,8 @@ +from __future__ import annotations + import logging from abc import ABC from abc import abstractmethod -from typing import Optional import numpy as np import torch @@ -18,7 +19,7 @@ class BaseWeights(ABC, NormalizerMixin): """Base class for the weights of the nodes.""" - def __init__(self, norm: Optional[str] = None) -> None: + def __init__(self, norm: str | None = None) -> None: self.norm = norm @abstractmethod @@ -92,9 +93,7 @@ class AreaWeights(BaseWeights): Compute the area attributes for each node. """ - def __init__( - self, norm: Optional[str] = None, radius: float = 1.0, centre: np.ndarray = np.array([0, 0, 0]) - ) -> None: + def __init__(self, norm: str | None = None, radius: float = 1.0, centre: np.ndarray = np.array([0, 0, 0])) -> None: super().__init__(norm) self.radius = radius self.centre = centre diff --git a/src/anemoi/graphs/nodes/builders/base.py b/src/anemoi/graphs/nodes/builders/base.py index a56552e..e0a688a 100644 --- a/src/anemoi/graphs/nodes/builders/base.py +++ b/src/anemoi/graphs/nodes/builders/base.py @@ -1,6 +1,5 @@ from abc import ABC from abc import abstractmethod -from typing import Optional import numpy as np import torch @@ -38,7 +37,7 @@ def register_nodes(self, graph: HeteroData) -> None: graph[self.name].node_type = type(self).__name__ return graph - def register_attributes(self, graph: HeteroData, config: Optional[DotDict] = None) -> HeteroData: + def register_attributes(self, graph: HeteroData, config: DotDict | None = None) -> HeteroData: """Register attributes in the nodes of the graph specified. Parameters @@ -80,7 +79,7 @@ def reshape_coords(self, latitudes: np.ndarray, longitudes: np.ndarray) -> torch coords = np.deg2rad(coords) return torch.tensor(coords, dtype=torch.float32) - def update_graph(self, graph: HeteroData, attr_config: Optional[DotDict] = None) -> HeteroData: + def update_graph(self, graph: HeteroData, attr_config: DotDict | None = None) -> HeteroData: """Update the graph with new nodes. Parameters diff --git a/src/anemoi/graphs/nodes/builders/from_refined_icosahedron.py b/src/anemoi/graphs/nodes/builders/from_refined_icosahedron.py index 11cb9f2..db08854 100644 --- a/src/anemoi/graphs/nodes/builders/from_refined_icosahedron.py +++ b/src/anemoi/graphs/nodes/builders/from_refined_icosahedron.py @@ -1,8 +1,8 @@ +from __future__ import annotations + import logging from abc import ABC from abc import abstractmethod -from typing import Tuple -from typing import Union import networkx as nx import numpy as np @@ -29,7 +29,7 @@ class IcosahedralNodes(BaseNodeBuilder, ABC): def __init__( self, - resolution: Union[int, list[int]], + resolution: int | list[int], name: str, ) -> None: if isinstance(resolution, int): @@ -51,7 +51,7 @@ def get_coordinates(self) -> torch.Tensor: return torch.tensor(coords_rad[self.node_ordering], dtype=torch.float32) @abstractmethod - def create_nodes(self) -> Tuple[nx.DiGraph, np.ndarray, list[int]]: ... + def create_nodes(self) -> tuple[nx.DiGraph, np.ndarray, list[int]]: ... def register_attributes(self, graph: HeteroData, config: DotDict) -> HeteroData: graph[self.name]["_resolutions"] = self.resolutions @@ -94,7 +94,7 @@ class TriNodes(IcosahedralNodes): It depends on the trimesh Python library. """ - def create_nodes(self) -> Tuple[nx.Graph, np.ndarray, list[int]]: + def create_nodes(self) -> tuple[nx.Graph, np.ndarray, list[int]]: return create_icosahedral_nodes(resolution=max(self.resolutions)) @@ -104,7 +104,7 @@ class HexNodes(IcosahedralNodes): It depends on the h3 Python library. """ - def create_nodes(self) -> Tuple[nx.Graph, np.ndarray, list[int]]: + def create_nodes(self) -> tuple[nx.Graph, np.ndarray, list[int]]: return create_hexagonal_nodes(resolution=max(self.resolutions)) @@ -119,7 +119,7 @@ class LimitedAreaTriNodes(LimitedAreaIcosahedralNodes): The area of interest mask builder. """ - def create_nodes(self) -> Tuple[nx.Graph, np.ndarray, list[int]]: + def create_nodes(self) -> tuple[nx.Graph, np.ndarray, list[int]]: return create_icosahedral_nodes(resolution=max(self.resolutions), aoi_mask_builder=self.aoi_mask_builder) @@ -134,5 +134,5 @@ class LimitedAreaHexNodes(LimitedAreaIcosahedralNodes): The area of interest mask builder. """ - def create_nodes(self) -> Tuple[nx.Graph, np.ndarray, list[int]]: + def create_nodes(self) -> tuple[nx.Graph, np.ndarray, list[int]]: return create_hexagonal_nodes(resolution=max(self.resolutions), aoi_mask_builder=self.aoi_mask_builder) diff --git a/src/anemoi/graphs/utils.py b/src/anemoi/graphs/utils.py index 8999bc6..c895426 100644 --- a/src/anemoi/graphs/utils.py +++ b/src/anemoi/graphs/utils.py @@ -1,11 +1,11 @@ -from typing import Optional +from __future__ import annotations import numpy as np import torch from sklearn.neighbors import NearestNeighbors -def get_nearest_neighbour(coords_rad: torch.Tensor, mask: Optional[torch.Tensor] = None) -> NearestNeighbors: +def get_nearest_neighbour(coords_rad: torch.Tensor, mask: torch.Tensor | None = None) -> NearestNeighbors: """Get NearestNeighbour object fitted to coordinates. Parameters @@ -32,7 +32,7 @@ def get_nearest_neighbour(coords_rad: torch.Tensor, mask: Optional[torch.Tensor] return nearest_neighbour -def get_grid_reference_distance(coords_rad: torch.Tensor, mask: Optional[torch.Tensor] = None) -> float: +def get_grid_reference_distance(coords_rad: torch.Tensor, mask: torch.Tensor | None = None) -> float: """Get the reference distance of the grid. It is the maximum distance of a node in the mesh with respect to its nearest neighbour. From ddef15411f9356c441336cf53a24a096c3b01de0 Mon Sep 17 00:00:00 2001 From: Mario Santa Cruz Date: Mon, 19 Aug 2024 14:23:21 +0000 Subject: [PATCH 133/156] fix: import annotations (py3.9) --- src/anemoi/graphs/nodes/builders/base.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/anemoi/graphs/nodes/builders/base.py b/src/anemoi/graphs/nodes/builders/base.py index e0a688a..b0e670f 100644 --- a/src/anemoi/graphs/nodes/builders/base.py +++ b/src/anemoi/graphs/nodes/builders/base.py @@ -1,3 +1,5 @@ +from __future__ import annotations + from abc import ABC from abc import abstractmethod From 6acfc320375be651c2b260da30f8b06e354da6dc Mon Sep 17 00:00:00 2001 From: Mario Santa Cruz Date: Mon, 19 Aug 2024 15:04:34 +0000 Subject: [PATCH 134/156] tests: new tests for CutOutZarrDatasetNodes --- tests/conftest.py | 12 ++++++- tests/nodes/test_cutout_nodes.py | 56 ++++++++++++++++++++++++++++++++ 2 files changed, 67 insertions(+), 1 deletion(-) create mode 100644 tests/nodes/test_cutout_nodes.py diff --git a/tests/conftest.py b/tests/conftest.py index b801614..2fcc824 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -11,10 +11,12 @@ class MockZarrDataset: """Mock Zarr dataset with latitudes and longitudes attributes.""" - def __init__(self, latitudes, longitudes): + def __init__(self, latitudes, longitudes, grids=None): self.latitudes = latitudes self.longitudes = longitudes self.num_nodes = len(latitudes) + if grids is not None: + self.grids = grids @pytest.fixture @@ -24,6 +26,14 @@ def mock_zarr_dataset() -> MockZarrDataset: return MockZarrDataset(latitudes=coords[:, 0], longitudes=coords[:, 1]) +@pytest.fixture +def mock_zarr_dataset_cutout() -> MockZarrDataset: + """Mock zarr dataset with nodes.""" + coords = 2 * torch.pi * np.array([[lat, lon] for lat in lats for lon in lons]) + grids = int(0.3 * len(coords)), int(0.7 * len(coords)) + return MockZarrDataset(latitudes=coords[:, 0], longitudes=coords[:, 1], grids=grids) + + @pytest.fixture def mock_grids_path(tmp_path) -> tuple[str, int]: """Mock grid_definition_path with files for 3 resolutions.""" diff --git a/tests/nodes/test_cutout_nodes.py b/tests/nodes/test_cutout_nodes.py new file mode 100644 index 0000000..a18424d --- /dev/null +++ b/tests/nodes/test_cutout_nodes.py @@ -0,0 +1,56 @@ +import pytest +import torch +from torch_geometric.data import HeteroData + +from anemoi.graphs.nodes.attributes import AreaWeights +from anemoi.graphs.nodes.attributes import UniformWeights +from anemoi.graphs.nodes.builders import from_file + + +def test_init(mocker, mock_zarr_dataset_cutout): + """Test CutOutZarrDatasetNodes initialization.""" + mocker.patch.object(from_file, "open_dataset", return_value=mock_zarr_dataset_cutout) + node_builder = from_file.CutOutZarrDatasetNodes( + forcing_dataset="global.zarr", lam_dataset="lam.zarr", name="test_nodes" + ) + + assert isinstance(node_builder, from_file.BaseNodeBuilder) + assert isinstance(node_builder, from_file.CutOutZarrDatasetNodes) + + +def test_fail_init(): + """Test CutOutZarrDatasetNodes initialization with invalid resolution.""" + with pytest.raises(TypeError): + from_file.CutOutZarrDatasetNodes("global_dataset.zarr", name="test_nodes") + + +def test_register_nodes(mocker, mock_zarr_dataset_cutout): + """Test CutOutZarrDatasetNodes register correctly the nodes.""" + mocker.patch.object(from_file, "open_dataset", return_value=mock_zarr_dataset_cutout) + node_builder = from_file.CutOutZarrDatasetNodes( + forcing_dataset="global.zarr", lam_dataset="lam.zarr", name="test_nodes" + ) + graph = HeteroData() + + graph = node_builder.register_nodes(graph) + + assert graph["test_nodes"].x is not None + assert isinstance(graph["test_nodes"].x, torch.Tensor) + assert graph["test_nodes"].x.shape == (node_builder.dataset.num_nodes, 2) + assert graph["test_nodes"].node_type == "CutOutZarrDatasetNodes" + + +@pytest.mark.parametrize("attr_class", [UniformWeights, AreaWeights]) +def test_register_attributes(mocker, mock_zarr_dataset_cutout, graph_with_nodes: HeteroData, attr_class): + """Test CutOutZarrDatasetNodes register correctly the weights.""" + mocker.patch.object(from_file, "open_dataset", return_value=mock_zarr_dataset_cutout) + node_builder = from_file.CutOutZarrDatasetNodes( + forcing_dataset="global.zarr", lam_dataset="lam.zarr", name="test_nodes" + ) + config = {"test_attr": {"_target_": f"anemoi.graphs.nodes.attributes.{attr_class.__name__}"}} + + graph = node_builder.register_attributes(graph_with_nodes, config) + + assert graph["test_nodes"]["test_attr"] is not None + assert isinstance(graph["test_nodes"]["test_attr"], torch.Tensor) + assert graph["test_nodes"]["test_attr"].shape[0] == graph["test_nodes"].x.shape[0] From 36aa407777a1c5e90f015b42da6dc23b3e8f2496 Mon Sep 17 00:00:00 2001 From: Mario Santa Cruz Date: Mon, 19 Aug 2024 15:43:24 +0000 Subject: [PATCH 135/156] test: new test for KNNAreaMaskBuilder --- src/anemoi/graphs/generate/masks.py | 9 ++++++ tests/conftest.py | 2 ++ tests/generate/test_masks.py | 48 +++++++++++++++++++++++++++++ 3 files changed, 59 insertions(+) create mode 100644 tests/generate/test_masks.py diff --git a/src/anemoi/graphs/generate/masks.py b/src/anemoi/graphs/generate/masks.py index e1a4128..ad75c5e 100644 --- a/src/anemoi/graphs/generate/masks.py +++ b/src/anemoi/graphs/generate/masks.py @@ -32,6 +32,8 @@ class KNNAreaMaskBuilder: """ def __init__(self, reference_node_name: str, margin_radius_km: float = 100, mask_attr_name: str = None): + assert isinstance(margin_radius_km, (int, float)), "The margin radius must be a number." + assert margin_radius_km > 0, "The margin radius must be positive." self.nearest_neighbour = NearestNeighbors(metric="haversine", n_jobs=4) self.margin_radius_km = margin_radius_km @@ -40,9 +42,16 @@ def __init__(self, reference_node_name: str, margin_radius_km: float = 100, mask def fit(self, graph: HeteroData): """Fit the KNN model to the nodes of interest.""" + assert ( + self.reference_node_name in graph.node_types + ), f'Reference node "{self.reference_node_name}" not found in the graph.' reference_mask_str = self.reference_node_name + coords_rad = graph[self.reference_node_name].x.numpy() if self.mask_attr_name is not None: + assert ( + self.mask_attr_name in graph[self.reference_node_name].node_attrs() + ), f'Mask attribute "{self.mask_attr_name}" not found in the reference nodes.' mask = graph[self.reference_node_name][self.mask_attr_name].squeeze() coords_rad = coords_rad[mask] reference_mask_str += f" ({self.mask_attr_name})" diff --git a/tests/conftest.py b/tests/conftest.py index 2fcc824..23208c9 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -50,6 +50,7 @@ def graph_with_nodes() -> HeteroData: coords = np.array([[lat, lon] for lat in lats for lon in lons]) graph = HeteroData() graph["test_nodes"].x = 2 * torch.pi * torch.tensor(coords) + graph["test_nodes"].mask = torch.tensor([True] * len(coords)) return graph @@ -59,6 +60,7 @@ def graph_nodes_and_edges() -> HeteroData: coords = np.array([[lat, lon] for lat in lats for lon in lons]) graph = HeteroData() graph["test_nodes"].x = 2 * torch.pi * torch.tensor(coords) + graph["test_nodes"].mask = torch.tensor([True] * len(coords)) graph[("test_nodes", "to", "test_nodes")].edge_index = torch.tensor([[0, 1], [1, 2], [2, 3], [3, 0]]) return graph diff --git a/tests/generate/test_masks.py b/tests/generate/test_masks.py new file mode 100644 index 0000000..651bdb7 --- /dev/null +++ b/tests/generate/test_masks.py @@ -0,0 +1,48 @@ +import pytest +from sklearn.neighbors import NearestNeighbors +from torch_geometric.data import HeteroData + +from anemoi.graphs.generate.masks import KNNAreaMaskBuilder + + +def test_init(): + """Test KNNAreaMaskBuilder initialization.""" + mask_builder1 = KNNAreaMaskBuilder("nodes") + mask_builder2 = KNNAreaMaskBuilder("nodes", margin_radius_km=120) + mask_builder3 = KNNAreaMaskBuilder("nodes", mask_attr_name="mask") + mask_builder4 = KNNAreaMaskBuilder("nodes", margin_radius_km=120, mask_attr_name="mask") + + assert isinstance(mask_builder1, KNNAreaMaskBuilder) + assert isinstance(mask_builder2, KNNAreaMaskBuilder) + assert isinstance(mask_builder3, KNNAreaMaskBuilder) + assert isinstance(mask_builder4, KNNAreaMaskBuilder) + + assert isinstance(mask_builder1.nearest_neighbour, NearestNeighbors) + assert isinstance(mask_builder2.nearest_neighbour, NearestNeighbors) + assert isinstance(mask_builder3.nearest_neighbour, NearestNeighbors) + assert isinstance(mask_builder4.nearest_neighbour, NearestNeighbors) + + +@pytest.mark.parametrize("margin", [-1, "120", None]) +def test_fail_init_wrong_margin(margin: int): + """Test KNNAreaMaskBuilder initialization with invalid margin.""" + with pytest.raises(AssertionError): + KNNAreaMaskBuilder("nodes", margin_radius_km=margin) + + +@pytest.mark.parametrize("mask", [None, "mask"]) +def test_fit(graph_with_nodes: HeteroData, mask: str): + """Test KNNAreaMaskBuilder fit.""" + mask_builder = KNNAreaMaskBuilder("test_nodes", mask_attr_name=mask) + assert not hasattr(mask_builder.nearest_neighbour, "n_samples_fit_") + + mask_builder.fit(graph_with_nodes) + + assert mask_builder.nearest_neighbour.n_samples_fit_ == graph_with_nodes["test_nodes"].num_nodes + + +def test_fit_fail(graph_with_nodes): + """Test KNNAreaMaskBuilder fit with wrong graph.""" + mask_builder = KNNAreaMaskBuilder("wrong_nodes") + with pytest.raises(AssertionError): + mask_builder.fit(graph_with_nodes) From d6a8a20b79ccdf32cd971c7b5b0ee9ea401b6946 Mon Sep 17 00:00:00 2001 From: Mario Santa Cruz Date: Wed, 21 Aug 2024 07:55:47 +0000 Subject: [PATCH 136/156] feat: mask_attr_name should be optional --- src/anemoi/graphs/generate/masks.py | 4 +++- src/anemoi/graphs/nodes/builders/from_file.py | 6 ++++-- src/anemoi/graphs/nodes/builders/from_healpix.py | 6 ++++-- .../graphs/nodes/builders/from_refined_icosahedron.py | 4 ++-- 4 files changed, 13 insertions(+), 7 deletions(-) diff --git a/src/anemoi/graphs/generate/masks.py b/src/anemoi/graphs/generate/masks.py index ad75c5e..d624347 100644 --- a/src/anemoi/graphs/generate/masks.py +++ b/src/anemoi/graphs/generate/masks.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import logging import numpy as np @@ -31,7 +33,7 @@ class KNNAreaMaskBuilder: Get the mask for the nodes based on the distance to the reference nodes. """ - def __init__(self, reference_node_name: str, margin_radius_km: float = 100, mask_attr_name: str = None): + def __init__(self, reference_node_name: str, margin_radius_km: float = 100, mask_attr_name: str | None = None): assert isinstance(margin_radius_km, (int, float)), "The margin radius must be a number." assert margin_radius_km > 0, "The margin radius must be positive." diff --git a/src/anemoi/graphs/nodes/builders/from_file.py b/src/anemoi/graphs/nodes/builders/from_file.py index 1a2eb54..7fa7661 100644 --- a/src/anemoi/graphs/nodes/builders/from_file.py +++ b/src/anemoi/graphs/nodes/builders/from_file.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import logging from pathlib import Path @@ -130,9 +132,9 @@ def __init__( self, resolution: str, grid_definition_path: str, - name: str, reference_node_name: str, - mask_attr_name: str, + name: str, + mask_attr_name: str | None = None, margin_radius_km: float = 100.0, ) -> None: diff --git a/src/anemoi/graphs/nodes/builders/from_healpix.py b/src/anemoi/graphs/nodes/builders/from_healpix.py index a1b8c46..a8ef080 100644 --- a/src/anemoi/graphs/nodes/builders/from_healpix.py +++ b/src/anemoi/graphs/nodes/builders/from_healpix.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import logging import numpy as np @@ -65,9 +67,9 @@ class LimitedAreaHEALPixNodes(HEALPixNodes): def __init__( self, resolution: str, - name: str, reference_node_name: str, - mask_attr_name: str, + name: str, + mask_attr_name: str | None = None, margin_radius_km: float = 100.0, ) -> None: diff --git a/src/anemoi/graphs/nodes/builders/from_refined_icosahedron.py b/src/anemoi/graphs/nodes/builders/from_refined_icosahedron.py index db08854..5f39063 100644 --- a/src/anemoi/graphs/nodes/builders/from_refined_icosahedron.py +++ b/src/anemoi/graphs/nodes/builders/from_refined_icosahedron.py @@ -73,9 +73,9 @@ class LimitedAreaIcosahedralNodes(IcosahedralNodes): def __init__( self, resolution: int | list[int], - name: str, reference_node_name: str, - mask_attr_name: str, + name: str, + mask_attr_name: str | None = None, margin_radius_km: float = 100.0, ) -> None: From 7a53363d5d67c78cc63f6d6d5f1ec5b0850a0cdd Mon Sep 17 00:00:00 2001 From: Mario Santa Cruz Date: Wed, 21 Aug 2024 09:12:50 +0000 Subject: [PATCH 137/156] refactor: KNNAreaMAskBuilder --- src/anemoi/graphs/generate/masks.py | 24 ++++++++++++++++++++---- 1 file changed, 20 insertions(+), 4 deletions(-) diff --git a/src/anemoi/graphs/generate/masks.py b/src/anemoi/graphs/generate/masks.py index d624347..1b804f0 100644 --- a/src/anemoi/graphs/generate/masks.py +++ b/src/anemoi/graphs/generate/masks.py @@ -27,6 +27,8 @@ class KNNAreaMaskBuilder: Methods ------- + fit_coords(coords_rad: np.ndarray) + Fit the KNN model to the coordinates in radians. fit(graph: HeteroData) Fit the KNN model to the reference nodes. get_mask(coords_rad: np.ndarray) -> np.ndarray @@ -42,12 +44,11 @@ def __init__(self, reference_node_name: str, margin_radius_km: float = 100, mask self.reference_node_name = reference_node_name self.mask_attr_name = mask_attr_name - def fit(self, graph: HeteroData): - """Fit the KNN model to the nodes of interest.""" + def get_reference_coords(self, graph: HeteroData) -> np.ndarray: + """Retrive coordinates from the reference nodes.""" assert ( self.reference_node_name in graph.node_types ), f'Reference node "{self.reference_node_name}" not found in the graph.' - reference_mask_str = self.reference_node_name coords_rad = graph[self.reference_node_name].x.numpy() if self.mask_attr_name is not None: @@ -56,15 +57,30 @@ def fit(self, graph: HeteroData): ), f'Mask attribute "{self.mask_attr_name}" not found in the reference nodes.' mask = graph[self.reference_node_name][self.mask_attr_name].squeeze() coords_rad = coords_rad[mask] + + return coords_rad + + def fit_coords(self, coords_rad: np.ndarray): + """Fit the KNN model to the coordinates in radians.""" + self.nearest_neighbour.fit(coords_rad) + + def fit(self, graph: HeteroData): + """Fit the KNN model to the nodes of interest.""" + # Prepare string for logging + reference_mask_str = self.reference_node_name + if self.mask_attr_name is not None: reference_mask_str += f" ({self.mask_attr_name})" + # Fit to the reference nodes + coords_rad = self.get_reference_coords(graph) + self.fit_coords(coords_rad) + LOGGER.info( 'Fitting %s with %d reference nodes from "%s".', self.__class__.__name__, len(coords_rad), reference_mask_str, ) - self.nearest_neighbour.fit(coords_rad) def get_mask(self, coords_rad: np.ndarray) -> np.ndarray: """Compute a mask based on the distance to the reference nodes.""" From 19e38b0ee73ec68521c76526b78d5f1e2b1d36d1 Mon Sep 17 00:00:00 2001 From: theissenhelen Date: Mon, 2 Sep 2024 14:49:09 +0100 Subject: [PATCH 138/156] docs: remove duplication --- CHANGELOG.md | 1 - 1 file changed, 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 285dae0..5fac590 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -16,7 +16,6 @@ Keep it human-readable, your future self will thank you! - New node builder classes, LimitedAreaXXXXXNodes, to create nodes within an Area of Interest (AOI). - Expanded MultiScaleEdges to support multi-scale connections in limited area graphs. - HEALPixNodes - nodebuilder based on Hierarchical Equal Area isoLatitude Pixelation of a sphere -- HEALPixNodes - nodebuilder based on Hierarchical Equal Area isoLatitude Pixelation of a sphere. - Inspection tools: interactive plots, and distribution plots of edge & node attributes. - Graph description print in the console. - CLI entry point: 'anemoi-graphs inspect ...'. From 5d247deecef2ccdac9766fc34c685a833e3b7ff5 Mon Sep 17 00:00:00 2001 From: theissenhelen Date: Mon, 2 Sep 2024 14:58:19 +0100 Subject: [PATCH 139/156] refactor: remove second return --- src/anemoi/graphs/edges/builder.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/src/anemoi/graphs/edges/builder.py b/src/anemoi/graphs/edges/builder.py index 1b22b2e..d502bee 100644 --- a/src/anemoi/graphs/edges/builder.py +++ b/src/anemoi/graphs/edges/builder.py @@ -119,10 +119,8 @@ def update_graph(self, graph: HeteroData, attrs_config: DotDict | None = None) - """ graph = self.register_edges(graph) - if attrs_config is None: - return graph - - graph = self.register_attributes(graph, attrs_config) + if attrs_config is not None: + graph = self.register_attributes(graph, attrs_config) return graph From ed6f960207bfce509d580cec8f1eaf93d48e139a Mon Sep 17 00:00:00 2001 From: theissenhelen Date: Tue, 3 Sep 2024 10:18:05 +0000 Subject: [PATCH 140/156] docs: output folder missing --- docs/usage/getting_started.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/usage/getting_started.rst b/docs/usage/getting_started.rst index b4998e1..4c322a2 100644 --- a/docs/usage/getting_started.rst +++ b/docs/usage/getting_started.rst @@ -54,7 +54,7 @@ following command: .. code:: console - $ anemoi-graphs inspect graph.pt + $ anemoi-graphs inspect graph.pt output_plots This will generate the following graph: From af1e5207ca91f199648fb2c2d17ca74b9d11f273 Mon Sep 17 00:00:00 2001 From: theissenhelen Date: Wed, 11 Sep 2024 15:18:50 +0100 Subject: [PATCH 141/156] docs: add PR to changelog --- CHANGELOG.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 12b9a9b..04a3679 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -11,7 +11,7 @@ Keep it human-readable, your future self will thank you! ## [Unreleased](https://github.com/ecmwf/anemoi-graphs/compare/0.3.0...HEAD) ### Added -- New node builder class, CutOutZarrDatasetNodes, to create nodes from 2 datasets. +- New node builder class, CutOutZarrDatasetNodes, to create nodes from 2 datasets. (#30) - New class, KNNAreaMaskBuilder, to specify Area of Interest (AOI) based on a set of nodes. - New node builder classes, LimitedAreaXXXXXNodes, to create nodes within an Area of Interest (AOI). - Expanded MultiScaleEdges to support multi-scale connections in limited area graphs. From 3b65701055bcf0d4d935d2cd18c5641d2111ea3a Mon Sep 17 00:00:00 2001 From: theissenhelen Date: Wed, 11 Sep 2024 15:35:15 +0100 Subject: [PATCH 142/156] docs: added missing PRs in changelog --- CHANGELOG.md | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 04a3679..43a8688 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -12,9 +12,9 @@ Keep it human-readable, your future self will thank you! ### Added - New node builder class, CutOutZarrDatasetNodes, to create nodes from 2 datasets. (#30) -- New class, KNNAreaMaskBuilder, to specify Area of Interest (AOI) based on a set of nodes. -- New node builder classes, LimitedAreaXXXXXNodes, to create nodes within an Area of Interest (AOI). -- Expanded MultiScaleEdges to support multi-scale connections in limited area graphs. +- New class, KNNAreaMaskBuilder, to specify Area of Interest (AOI) based on a set of nodes. (#30) +- New node builder classes, LimitedAreaXXXXXNodes, to create nodes within an Area of Interest (AOI). (#30) +- Expanded MultiScaleEdges to support multi-scale connections in limited area graphs. (#30) ## [0.3.0 Anemoi-graphs, minor release](https://github.com/ecmwf/anemoi-graphs/compare/0.2.1...0.3.0) - 2024-09-03 From 524610a4cea89b79ccc78de45c379cb3556b6542 Mon Sep 17 00:00:00 2001 From: Mario Santa Cruz Date: Thu, 12 Sep 2024 08:55:19 +0000 Subject: [PATCH 143/156] fix: address @mchantry's comments --- src/anemoi/graphs/generate/hexagonal.py | 10 ++++------ src/anemoi/graphs/generate/icosahedral.py | 8 ++++---- 2 files changed, 8 insertions(+), 10 deletions(-) diff --git a/src/anemoi/graphs/generate/hexagonal.py b/src/anemoi/graphs/generate/hexagonal.py index fa947ea..a70eef5 100644 --- a/src/anemoi/graphs/generate/hexagonal.py +++ b/src/anemoi/graphs/generate/hexagonal.py @@ -10,7 +10,7 @@ def create_hexagonal_nodes( resolution: int, - aoi_mask_builder: KNNAreaMaskBuilder | None = None, + area_mask_builder: KNNAreaMaskBuilder | None = None, ) -> tuple[nx.Graph, np.ndarray, list[int]]: """Creates a global mesh from a refined icosahedron. @@ -21,7 +21,7 @@ def create_hexagonal_nodes( ---------- resolution : int Level of mesh resolution to consider. - aoi_mask_builder : KNNAreaMaskBuilder, optional + area_mask_builder : KNNAreaMaskBuilder, optional KNNAreaMaskBuilder with the cloud of points to limit the mesh area, by default None. Returns @@ -39,8 +39,8 @@ def create_hexagonal_nodes( node_ordering = get_coordinates_ordering(coords_rad) - if aoi_mask_builder is not None: - aoi_mask = aoi_mask_builder.get_mask(coords_rad) + if area_mask_builder is not None: + aoi_mask = area_mask_builder.get_mask(coords_rad) node_ordering = node_ordering[aoi_mask[node_ordering]] graph = create_hexagonal_nx_graph_from_coords(nodes, node_ordering) @@ -77,8 +77,6 @@ def get_nodes_at_resolution( ) -> list[str]: """Get nodes at a specified refinement level over the entire globe. - If area is not None, it will return the nodes within the specified area - Parameters ---------- resolution : int diff --git a/src/anemoi/graphs/generate/icosahedral.py b/src/anemoi/graphs/generate/icosahedral.py index ff85fd7..993cec2 100644 --- a/src/anemoi/graphs/generate/icosahedral.py +++ b/src/anemoi/graphs/generate/icosahedral.py @@ -67,7 +67,7 @@ def add_edges_to_nx_graph( graph: nx.DiGraph, resolutions: list[int], x_hops: int = 1, - aoi_mask_builder: KNNAreaMaskBuilder | None = None, + area_mask_builder: KNNAreaMaskBuilder | None = None, ) -> nx.DiGraph: """Adds the edges to the graph. @@ -82,7 +82,7 @@ def add_edges_to_nx_graph( Levels of mesh refinement levels to consider. x_hops : int, optional Number of hops between 2 nodes to consider them neighbours, by default 1. - aoi_mask_builder : KNNAreaMaskBuilder + area_mask_builder : KNNAreaMaskBuilder NearestNeighbors with the cloud of points to limit the mesh area, by default None. Returns @@ -110,8 +110,8 @@ def add_edges_to_nx_graph( r_vertices_rad = cartesian_to_latlon_rad(r_sphere.vertices) # Limit area of mesh points. - if aoi_mask_builder is not None: - aoi_mask = aoi_mask_builder.get_mask(vertices_rad) + if area_mask_builder is not None: + aoi_mask = area_mask_builder.get_mask(vertices_rad) valid_nodes = np.where(aoi_mask)[0] else: valid_nodes = None From 24ff1fcc2831e502cf7317491e41b0380f1f4b8f Mon Sep 17 00:00:00 2001 From: Mario Santa Cruz Date: Thu, 12 Sep 2024 14:13:01 +0000 Subject: [PATCH 144/156] fix: rename according to the refinement --- src/anemoi/graphs/create.py | 2 +- src/anemoi/graphs/edges/builder.py | 8 ++++---- .../generate/{hexagonal.py => hex_icosahedron.py} | 6 +++--- .../generate/{icosahedral.py => tri_icosahedron.py} | 8 ++++---- .../nodes/builders/from_refined_icosahedron.py | 12 ++++++------ 5 files changed, 18 insertions(+), 18 deletions(-) rename src/anemoi/graphs/generate/{hexagonal.py => hex_icosahedron.py} (97%) rename src/anemoi/graphs/generate/{icosahedral.py => tri_icosahedron.py} (96%) diff --git a/src/anemoi/graphs/create.py b/src/anemoi/graphs/create.py index 9f10c81..51de9a4 100644 --- a/src/anemoi/graphs/create.py +++ b/src/anemoi/graphs/create.py @@ -41,7 +41,7 @@ def generate_graph(self) -> HeteroData: graph, nodes_cfg.get("attributes", {}) ) - for edges_cfg in self.config.edges: + for edges_cfg in self.config.get("edges", {}): graph = instantiate(edges_cfg.edge_builder, edges_cfg.source_name, edges_cfg.target_name).update_graph( graph, edges_cfg.get("attributes", {}) ) diff --git a/src/anemoi/graphs/edges/builder.py b/src/anemoi/graphs/edges/builder.py index bd65d57..79e4fa1 100644 --- a/src/anemoi/graphs/edges/builder.py +++ b/src/anemoi/graphs/edges/builder.py @@ -14,8 +14,8 @@ from torch_geometric.data.storage import NodeStorage from anemoi.graphs import EARTH_RADIUS -from anemoi.graphs.generate import hexagonal -from anemoi.graphs.generate import icosahedral +from anemoi.graphs.generate import hex_icosahedron +from anemoi.graphs.generate import tri_icosahedron from anemoi.graphs.nodes.builders.from_refined_icosahedron import HexNodes from anemoi.graphs.nodes.builders.from_refined_icosahedron import LimitedAreaHexNodes from anemoi.graphs.nodes.builders.from_refined_icosahedron import LimitedAreaTriNodes @@ -295,7 +295,7 @@ def __init__(self, source_name: str, target_name: str, x_hops: int): self.node_type = None def add_edges_from_tri_nodes(self, nodes: NodeStorage) -> NodeStorage: - nodes["_nx_graph"] = icosahedral.add_edges_to_nx_graph( + nodes["_nx_graph"] = tri_icosahedron.add_edges_to_nx_graph( nodes["_nx_graph"], resolutions=nodes["_resolutions"], x_hops=self.x_hops, @@ -305,7 +305,7 @@ def add_edges_from_tri_nodes(self, nodes: NodeStorage) -> NodeStorage: return nodes def add_edges_from_hex_nodes(self, nodes: NodeStorage) -> NodeStorage: - nodes["_nx_graph"] = hexagonal.add_edges_to_nx_graph( + nodes["_nx_graph"] = hex_icosahedron.add_edges_to_nx_graph( nodes["_nx_graph"], resolutions=nodes["_resolutions"], x_hops=self.x_hops, diff --git a/src/anemoi/graphs/generate/hexagonal.py b/src/anemoi/graphs/generate/hex_icosahedron.py similarity index 97% rename from src/anemoi/graphs/generate/hexagonal.py rename to src/anemoi/graphs/generate/hex_icosahedron.py index a70eef5..de24c39 100644 --- a/src/anemoi/graphs/generate/hexagonal.py +++ b/src/anemoi/graphs/generate/hex_icosahedron.py @@ -8,7 +8,7 @@ from anemoi.graphs.generate.utils import get_coordinates_ordering -def create_hexagonal_nodes( +def create_hex_nodes( resolution: int, area_mask_builder: KNNAreaMaskBuilder | None = None, ) -> tuple[nx.Graph, np.ndarray, list[int]]: @@ -43,12 +43,12 @@ def create_hexagonal_nodes( aoi_mask = area_mask_builder.get_mask(coords_rad) node_ordering = node_ordering[aoi_mask[node_ordering]] - graph = create_hexagonal_nx_graph_from_coords(nodes, node_ordering) + graph = create_nx_graph_from_hex_coords(nodes, node_ordering) return graph, coords_rad, list(node_ordering) -def create_hexagonal_nx_graph_from_coords(nodes: list[str], node_ordering: np.ndarray) -> nx.Graph: +def create_nx_graph_from_hex_coords(nodes: list[str], node_ordering: np.ndarray) -> nx.Graph: """Add all nodes at a specified refinement level to a graph. Parameters diff --git a/src/anemoi/graphs/generate/icosahedral.py b/src/anemoi/graphs/generate/tri_icosahedron.py similarity index 96% rename from src/anemoi/graphs/generate/icosahedral.py rename to src/anemoi/graphs/generate/tri_icosahedron.py index 993cec2..77b033b 100644 --- a/src/anemoi/graphs/generate/icosahedral.py +++ b/src/anemoi/graphs/generate/tri_icosahedron.py @@ -12,10 +12,10 @@ from anemoi.graphs.generate.utils import get_coordinates_ordering -def create_icosahedral_nodes( +def create_tri_nodes( resolution: int, aoi_mask_builder: KNNAreaMaskBuilder | None = None ) -> tuple[nx.DiGraph, np.ndarray, list[int]]: - """Creates a global mesh following AIFS strategy. + """Creates a global mesh from a refined icosahedron. This method relies on the trimesh python library. @@ -46,12 +46,12 @@ def create_icosahedral_nodes( node_ordering = node_ordering[aoi_mask[node_ordering]] # Creates the graph, with the nodes sorted by latitude and longitude. - nx_graph = create_icosahedral_nx_graph_from_coords(coords_rad, node_ordering) + nx_graph = create_nx_graph_from_tri_coords(coords_rad, node_ordering) return nx_graph, coords_rad, list(node_ordering) -def create_icosahedral_nx_graph_from_coords(coords_rad: np.ndarray, node_ordering: np.ndarray) -> nx.DiGraph: +def create_nx_graph_from_tri_coords(coords_rad: np.ndarray, node_ordering: np.ndarray) -> nx.DiGraph: """Creates the networkx graph from the coordinates and the node ordering.""" graph = nx.DiGraph() for i, coords in enumerate(coords_rad[node_ordering]): diff --git a/src/anemoi/graphs/nodes/builders/from_refined_icosahedron.py b/src/anemoi/graphs/nodes/builders/from_refined_icosahedron.py index 5f39063..2f02597 100644 --- a/src/anemoi/graphs/nodes/builders/from_refined_icosahedron.py +++ b/src/anemoi/graphs/nodes/builders/from_refined_icosahedron.py @@ -10,9 +10,9 @@ from anemoi.utils.config import DotDict from torch_geometric.data import HeteroData -from anemoi.graphs.generate.hexagonal import create_hexagonal_nodes -from anemoi.graphs.generate.icosahedral import create_icosahedral_nodes +from anemoi.graphs.generate.hex_icosahedron import create_hex_nodes from anemoi.graphs.generate.masks import KNNAreaMaskBuilder +from anemoi.graphs.generate.tri_icosahedron import create_tri_nodes from anemoi.graphs.nodes.builders.base import BaseNodeBuilder LOGGER = logging.getLogger(__name__) @@ -95,7 +95,7 @@ class TriNodes(IcosahedralNodes): """ def create_nodes(self) -> tuple[nx.Graph, np.ndarray, list[int]]: - return create_icosahedral_nodes(resolution=max(self.resolutions)) + return create_tri_nodes(resolution=max(self.resolutions)) class HexNodes(IcosahedralNodes): @@ -105,7 +105,7 @@ class HexNodes(IcosahedralNodes): """ def create_nodes(self) -> tuple[nx.Graph, np.ndarray, list[int]]: - return create_hexagonal_nodes(resolution=max(self.resolutions)) + return create_hex_nodes(resolution=max(self.resolutions)) class LimitedAreaTriNodes(LimitedAreaIcosahedralNodes): @@ -120,7 +120,7 @@ class LimitedAreaTriNodes(LimitedAreaIcosahedralNodes): """ def create_nodes(self) -> tuple[nx.Graph, np.ndarray, list[int]]: - return create_icosahedral_nodes(resolution=max(self.resolutions), aoi_mask_builder=self.aoi_mask_builder) + return create_tri_nodes(resolution=max(self.resolutions), aoi_mask_builder=self.aoi_mask_builder) class LimitedAreaHexNodes(LimitedAreaIcosahedralNodes): @@ -135,4 +135,4 @@ class LimitedAreaHexNodes(LimitedAreaIcosahedralNodes): """ def create_nodes(self) -> tuple[nx.Graph, np.ndarray, list[int]]: - return create_hexagonal_nodes(resolution=max(self.resolutions), aoi_mask_builder=self.aoi_mask_builder) + return create_hex_nodes(resolution=max(self.resolutions), aoi_mask_builder=self.aoi_mask_builder) From e2601b4beaadd394d9d928c45f47b0a949b1cf12 Mon Sep 17 00:00:00 2001 From: Mario Santa Cruz Date: Fri, 13 Sep 2024 09:46:06 +0000 Subject: [PATCH 145/156] fix: edge case 1 set of nodes with 1 node attribute --- src/anemoi/graphs/plotting/displots.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/src/anemoi/graphs/plotting/displots.py b/src/anemoi/graphs/plotting/displots.py index 22316f5..5fe37ee 100644 --- a/src/anemoi/graphs/plotting/displots.py +++ b/src/anemoi/graphs/plotting/displots.py @@ -5,6 +5,7 @@ from typing import Union import matplotlib.pyplot as plt +import numpy as np import torch from torch_geometric.data import HeteroData from torch_geometric.data.storage import EdgeStorage @@ -83,7 +84,9 @@ def plot_distribution_attributes( # Define the layout _, axs = plt.subplots(num_items, dim_attrs, figsize=(10 * num_items, 10)) - if axs.ndim == 1: + if num_items == dim_attrs == 1: + axs = np.array([[axs]]) + elif axs.ndim == 1: axs = axs.reshape(num_items, dim_attrs) for i, (item_name, item_store) in enumerate(graph_items): From 4762f3f1458d33f2c56f7a3c9028ca7ffd6ddb16 Mon Sep 17 00:00:00 2001 From: Mario Santa Cruz Date: Mon, 16 Sep 2024 12:42:42 +0000 Subject: [PATCH 146/156] fix: delete repeated code --- src/anemoi/graphs/nodes/weights.py | 61 ------------------------------ 1 file changed, 61 deletions(-) delete mode 100644 src/anemoi/graphs/nodes/weights.py diff --git a/src/anemoi/graphs/nodes/weights.py b/src/anemoi/graphs/nodes/weights.py deleted file mode 100644 index 25419cc..0000000 --- a/src/anemoi/graphs/nodes/weights.py +++ /dev/null @@ -1,61 +0,0 @@ -import logging -from abc import ABC -from abc import abstractmethod -from typing import Optional - -import numpy as np -import torch -from scipy.spatial import SphericalVoronoi -from torch_geometric.data.storage import NodeStorage - -from anemoi.graphs.generate.transforms import to_sphere_xyz -from anemoi.graphs.normalizer import NormalizerMixin - -logger = logging.getLogger(__name__) - - -class BaseWeights(ABC, NormalizerMixin): - """Base class for the weights of the nodes.""" - - def __init__(self, norm: Optional[str] = None): - self.norm = norm - - @abstractmethod - def compute(self, nodes: NodeStorage, *args, **kwargs): ... - - def get_weights(self, *args, **kwargs) -> torch.Tensor: - weights = self.compute(*args, **kwargs) - if weights.ndim == 1: - weights = weights[:, np.newaxis] - norm_weights = self.normalize(weights) - return torch.tensor(norm_weights, dtype=torch.float32) - - -class UniformWeights(BaseWeights): - """Implements a uniform weight for the nodes.""" - - def compute(self, nodes: NodeStorage) -> np.ndarray: - return np.ones(nodes.num_nodes) - - -class AreaWeights(BaseWeights): - """Implements the area of the nodes as the weights.""" - - def __init__(self, norm: str = "unit-max", radius: float = 1.0, centre: np.ndarray = np.array([0, 0, 0])): - super().__init__(norm=norm) - - # Weighting of the nodes - self.radius = radius - self.centre = centre - - def compute(self, nodes: NodeStorage, *args, **kwargs) -> np.ndarray: - latitudes, longitudes = nodes.x[:, 0], nodes.x[:, 1] - points = to_sphere_xyz((latitudes, longitudes)) - sv = SphericalVoronoi(points, self.radius, self.centre) - area_weights = sv.calculate_areas() - logger.debug( - "There are %d of weights, which (unscaled) add up a total weight of %.2f.", - len(area_weights), - np.array(area_weights).sum(), - ) - return area_weights From fac6f36d25a0a63cfa32c8d025f4023c32ad43db Mon Sep 17 00:00:00 2001 From: Mario Santa Cruz <48736305+JPXKQX@users.noreply.github.com> Date: Tue, 17 Sep 2024 14:32:01 +0200 Subject: [PATCH 147/156] [Feature] Support node masking in edge builder (#50) * feat: support node masking in edge builder --------- Co-authored-by: Helen Theissen --- src/anemoi/graphs/create.py | 10 ++- src/anemoi/graphs/edges/builder.py | 108 ++++++++++++++++++++++++----- 2 files changed, 99 insertions(+), 19 deletions(-) diff --git a/src/anemoi/graphs/create.py b/src/anemoi/graphs/create.py index 51de9a4..17a46f7 100644 --- a/src/anemoi/graphs/create.py +++ b/src/anemoi/graphs/create.py @@ -42,9 +42,13 @@ def generate_graph(self) -> HeteroData: ) for edges_cfg in self.config.get("edges", {}): - graph = instantiate(edges_cfg.edge_builder, edges_cfg.source_name, edges_cfg.target_name).update_graph( - graph, edges_cfg.get("attributes", {}) - ) + graph = instantiate( + edges_cfg.edge_builder, + edges_cfg.source_name, + edges_cfg.target_name, + source_mask_attr_name=edges_cfg.get("source_mask_attr_name", None), + target_mask_attr_name=edges_cfg.get("target_mask_attr_name", None), + ).update_graph(graph, edges_cfg.get("attributes", {})) return graph diff --git a/src/anemoi/graphs/edges/builder.py b/src/anemoi/graphs/edges/builder.py index 79e4fa1..e43d005 100644 --- a/src/anemoi/graphs/edges/builder.py +++ b/src/anemoi/graphs/edges/builder.py @@ -9,6 +9,7 @@ import torch from anemoi.utils.config import DotDict from hydra.utils import instantiate +from scipy.sparse import coo_matrix from sklearn.neighbors import NearestNeighbors from torch_geometric.data import HeteroData from torch_geometric.data.storage import NodeStorage @@ -28,9 +29,17 @@ class BaseEdgeBuilder(ABC): """Base class for edge builders.""" - def __init__(self, source_name: str, target_name: str): + def __init__( + self, + source_name: str, + target_name: str, + source_mask_attr_name: str | None = None, + target_mask_attr_name: str | None = None, + ): self.source_name = source_name self.target_name = target_name + self.source_mask_attr_name = source_mask_attr_name + self.target_mask_attr_name = target_mask_attr_name @property def name(self) -> tuple[str, str, str]: @@ -125,7 +134,42 @@ def update_graph(self, graph: HeteroData, attrs_config: DotDict | None = None) - return graph -class KNNEdges(BaseEdgeBuilder): +class NodeMaskingMixin: + """Mixin class for masking source/target nodes when building edges.""" + + def get_node_coordinates( + self, source_nodes: NodeStorage, target_nodes: NodeStorage + ) -> tuple[np.ndarray, np.ndarray]: + """Get the node coordinates.""" + source_coords, target_coords = source_nodes.x.numpy(), target_nodes.x.numpy() + + if self.source_mask_attr_name is not None: + source_coords = source_coords[source_nodes[self.source_mask_attr_name].squeeze()] + + if self.target_mask_attr_name is not None: + target_coords = target_coords[target_nodes[self.target_mask_attr_name].squeeze()] + + return source_coords, target_coords + + def undo_masking(self, adj_matrix, source_nodes: NodeStorage, target_nodes: NodeStorage): + if self.target_mask_attr_name is not None: + target_mask = target_nodes[self.target_mask_attr_name].squeeze() + target_mapper = dict(zip(list(range(len(adj_matrix.row))), np.where(target_mask)[0])) + adj_matrix.row = np.vectorize(target_mapper.get)(adj_matrix.row) + + if self.source_mask_attr_name is not None: + source_mask = source_nodes[self.source_mask_attr_name].squeeze() + source_mapper = dict(zip(list(range(len(adj_matrix.col))), np.where(source_mask)[0])) + adj_matrix.col = np.vectorize(source_mapper.get)(adj_matrix.col) + + if self.source_mask_attr_name is not None or self.target_mask_attr_name is not None: + true_shape = target_nodes.x.shape[0], source_nodes.x.shape[0] + adj_matrix = coo_matrix((adj_matrix.data, (adj_matrix.row, adj_matrix.col)), shape=true_shape) + + return adj_matrix + + +class KNNEdges(BaseEdgeBuilder, NodeMaskingMixin): """Computes KNN based edges and adds them to the graph. Attributes @@ -136,6 +180,10 @@ class KNNEdges(BaseEdgeBuilder): The name of the target nodes. num_nearest_neighbours : int Number of nearest neighbours. + source_mask_attr_name : str | None + The name of the source mask attribute to filter edge connections. + target_mask_attr_name : str | None + The name of the target mask attribute to filter edge connections. Methods ------- @@ -147,22 +195,30 @@ class KNNEdges(BaseEdgeBuilder): Update the graph with the edges. """ - def __init__(self, source_name: str, target_name: str, num_nearest_neighbours: int): - super().__init__(source_name, target_name) + def __init__( + self, + source_name: str, + target_name: str, + num_nearest_neighbours: int, + source_mask_attr_name: str | None = None, + target_mask_attr_name: str | None = None, + ): + super().__init__(source_name, target_name, source_mask_attr_name, target_mask_attr_name) assert isinstance(num_nearest_neighbours, int), "Number of nearest neighbours must be an integer" assert num_nearest_neighbours > 0, "Number of nearest neighbours must be positive" self.num_nearest_neighbours = num_nearest_neighbours - def get_adjacency_matrix(self, source_nodes: np.ndarray, target_nodes: np.ndarray): + def get_adjacency_matrix(self, source_nodes: NodeStorage, target_nodes: NodeStorage): """Compute the adjacency matrix for the KNN method. Parameters ---------- - source_nodes : np.ndarray + source_nodes : NodeStorage The source nodes. - target_nodes : np.ndarray + target_nodes : NodeStorage The target nodes. """ + source_coords, target_coords = self.get_node_coordinates(source_nodes, target_nodes) assert self.num_nearest_neighbours is not None, "number of neighbors required for knn encoder" LOGGER.info( "Using KNN-Edges (with %d nearest neighbours) between %s and %s.", @@ -172,16 +228,20 @@ def get_adjacency_matrix(self, source_nodes: np.ndarray, target_nodes: np.ndarra ) nearest_neighbour = NearestNeighbors(metric="haversine", n_jobs=4) - nearest_neighbour.fit(source_nodes.x.numpy()) + nearest_neighbour.fit(source_coords) adj_matrix = nearest_neighbour.kneighbors_graph( - target_nodes.x.numpy(), + target_coords, n_neighbors=self.num_nearest_neighbours, mode="distance", ).tocoo() + + # Post-process the adjacency matrix. Add masked nodes. + adj_matrix = self.undo_masking(adj_matrix, source_nodes, target_nodes) + return adj_matrix -class CutOffEdges(BaseEdgeBuilder): +class CutOffEdges(BaseEdgeBuilder, NodeMaskingMixin): """Computes cut-off based edges and adds them to the graph. Attributes @@ -192,6 +252,10 @@ class CutOffEdges(BaseEdgeBuilder): The name of the target nodes. cutoff_factor : float Factor to multiply the grid reference distance to get the cut-off radius. + source_mask_attr_name : str | None + The name of the source mask attribute to filter edge connections. + target_mask_attr_name : str | None + The name of the target mask attribute to filter edge connections. Methods ------- @@ -203,8 +267,15 @@ class CutOffEdges(BaseEdgeBuilder): Update the graph with the edges. """ - def __init__(self, source_name: str, target_name: str, cutoff_factor: float): - super().__init__(source_name, target_name) + def __init__( + self, + source_name: str, + target_name: str, + cutoff_factor: float, + source_mask_attr_name: str | None = None, + target_mask_attr_name: str | None = None, + ): + super().__init__(source_name, target_name, source_mask_attr_name, target_mask_attr_name) assert isinstance(cutoff_factor, (int, float)), "Cutoff factor must be a float" assert cutoff_factor > 0, "Cutoff factor must be positive" self.cutoff_factor = cutoff_factor @@ -248,6 +319,7 @@ def get_adjacency_matrix(self, source_nodes: NodeStorage, target_nodes: NodeStor target_nodes : NodeStorage The target nodes. """ + source_coords, target_coords = self.get_node_coordinates(source_nodes, target_nodes) LOGGER.info( "Using CutOff-Edges (with radius = %.1f km) between %s and %s.", self.radius * EARTH_RADIUS, @@ -256,8 +328,12 @@ def get_adjacency_matrix(self, source_nodes: NodeStorage, target_nodes: NodeStor ) nearest_neighbour = NearestNeighbors(metric="haversine", n_jobs=4) - nearest_neighbour.fit(source_nodes.x) - adj_matrix = nearest_neighbour.radius_neighbors_graph(target_nodes.x, radius=self.radius).tocoo() + nearest_neighbour.fit(source_coords) + adj_matrix = nearest_neighbour.radius_neighbors_graph(target_coords, radius=self.radius).tocoo() + + # Post-process the adjacency matrix. Add masked nodes. + adj_matrix = self.undo_masking(adj_matrix, source_nodes, target_nodes) + return adj_matrix @@ -286,7 +362,7 @@ class MultiScaleEdges(BaseEdgeBuilder): VALID_NODES = [TriNodes, HexNodes, LimitedAreaTriNodes, LimitedAreaHexNodes] - def __init__(self, source_name: str, target_name: str, x_hops: int): + def __init__(self, source_name: str, target_name: str, x_hops: int, **kwargs): super().__init__(source_name, target_name) assert source_name == target_name, f"{self.__class__.__name__} requires source and target nodes to be the same." assert isinstance(x_hops, int), "Number of x_hops must be an integer" @@ -299,7 +375,7 @@ def add_edges_from_tri_nodes(self, nodes: NodeStorage) -> NodeStorage: nodes["_nx_graph"], resolutions=nodes["_resolutions"], x_hops=self.x_hops, - aoi_mask_builder=nodes.get("_aoi_mask_builder", None), + area_mask_builder=nodes.get("_area_mask_builder", None), ) return nodes From 2fdb8c3fd00df84bdd663132f1fe524d115e2ca5 Mon Sep 17 00:00:00 2001 From: Mario Santa Cruz Date: Tue, 24 Sep 2024 10:32:07 +0000 Subject: [PATCH 148/156] feat: support area masking of boundary forcing --- src/anemoi/graphs/nodes/builders/from_file.py | 13 +++++++++++-- 1 file changed, 11 insertions(+), 2 deletions(-) diff --git a/src/anemoi/graphs/nodes/builders/from_file.py b/src/anemoi/graphs/nodes/builders/from_file.py index 7fa7661..cec3e49 100644 --- a/src/anemoi/graphs/nodes/builders/from_file.py +++ b/src/anemoi/graphs/nodes/builders/from_file.py @@ -55,10 +55,19 @@ class CutOutZarrDatasetNodes(ZarrDatasetNodes): """Nodes from Zarr dataset.""" def __init__( - self, name: str, lam_dataset: str, forcing_dataset: str, thinning: int = 1, adjust: str = "all" + self, + name: str, + lam_dataset: str, + forcing_dataset: str, + thinning: int = 1, + forcing_area: list[float] | None = None, + adjust: str = "all", ) -> None: dataset_config = { - "cutout": [{"dataset": lam_dataset, "thinning": thinning}, {"dataset": forcing_dataset}], + "cutout": [ + {"dataset": lam_dataset, "thinning": thinning}, + {"dataset": forcing_dataset, "area": forcing_area}, + ], "adjust": adjust, } super().__init__(dataset_config, name) From 669d201194ab605acdb9083e2d8201cec7e9d047 Mon Sep 17 00:00:00 2001 From: Mario Santa Cruz Date: Tue, 24 Sep 2024 10:41:29 +0000 Subject: [PATCH 149/156] docs: added CutOutZarrDatasetNodes docstring --- src/anemoi/graphs/nodes/builders/from_file.py | 27 ++++++++++++++++++- 1 file changed, 26 insertions(+), 1 deletion(-) diff --git a/src/anemoi/graphs/nodes/builders/from_file.py b/src/anemoi/graphs/nodes/builders/from_file.py index cec3e49..e9dacb7 100644 --- a/src/anemoi/graphs/nodes/builders/from_file.py +++ b/src/anemoi/graphs/nodes/builders/from_file.py @@ -52,7 +52,32 @@ def get_coordinates(self) -> torch.Tensor: class CutOutZarrDatasetNodes(ZarrDatasetNodes): - """Nodes from Zarr dataset.""" + """Nodes from Zarr dataset. + + Attributes + ---------- + lam_dataset : str + The limited area dataset. + forcing_dataset : str + The forcing dataset. + thinning : int, optional + The thinning factor. Defaults to 1, which means no thinning. + forcing_area : list[float], optional + The area of the forcing dataset. Specify the longitude and + latitudes boundaries as (north, west, south, east). Defaults + to None, which means the forcing dataset is not cropped. + + Methods + ------- + get_coordinates() + Get the lat-lon coordinates of the nodes. + register_nodes(graph, name) + Register the nodes in the graph. + register_attributes(graph, name, config) + Register the attributes in the nodes of the graph specified. + update_graph(graph, name, attr_config) + Update the graph with new nodes and attributes. + """ def __init__( self, From 0deb8864bf7a862c3ffe6e7709ae3b6a78112686 Mon Sep 17 00:00:00 2001 From: Mario Santa Cruz Date: Tue, 24 Sep 2024 11:15:03 +0000 Subject: [PATCH 150/156] fix: cast to tuple --- src/anemoi/graphs/nodes/builders/from_file.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/src/anemoi/graphs/nodes/builders/from_file.py b/src/anemoi/graphs/nodes/builders/from_file.py index e9dacb7..0eeea9a 100644 --- a/src/anemoi/graphs/nodes/builders/from_file.py +++ b/src/anemoi/graphs/nodes/builders/from_file.py @@ -88,6 +88,11 @@ def __init__( forcing_area: list[float] | None = None, adjust: str = "all", ) -> None: + if forcing_area is not None: + forcing_area = tuple(forcing_area) + assert len(forcing_area) == 4, "The forcing area must be a list of 4 elements (north, west, south, east)." + LOGGER.info("Forcing dataset is cropped to area: %s", forcing_area) + dataset_config = { "cutout": [ {"dataset": lam_dataset, "thinning": thinning}, From 996514dd52f793b140433e260916ec7c68979ce7 Mon Sep 17 00:00:00 2001 From: Mario Santa Cruz Date: Thu, 3 Oct 2024 13:38:07 +0000 Subject: [PATCH 151/156] feat: update CHANGELOG --- CHANGELOG.md | 1 + src/anemoi/graphs/generate/hex_icosahedron.py | 2 -- 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 92471ce..208baac 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -19,6 +19,7 @@ Keep it human-readable, your future self will thank you! ### Changed - ci: small fixes and updates pre-commit, downsteam-ci (#49) +- feat: New argument 'forcing_area' in the CutOutZarDatasetNodes class. (#52) ## [0.3.0 Anemoi-graphs, minor release](https://github.com/ecmwf/anemoi-graphs/compare/0.2.1...0.3.0) - 2024-09-03 diff --git a/src/anemoi/graphs/generate/hex_icosahedron.py b/src/anemoi/graphs/generate/hex_icosahedron.py index de24c39..3306164 100644 --- a/src/anemoi/graphs/generate/hex_icosahedron.py +++ b/src/anemoi/graphs/generate/hex_icosahedron.py @@ -112,8 +112,6 @@ def add_edges_to_nx_graph( depth_children : int The number of resolution levels to consider for the connections of children. Defaults to 1, which includes connections up to the next resolution level. - aoi_mask_builder : KNNAreaMaskBuilder - NearestNeighbors with the cloud of points to limit the mesh area, by default None. Returns ------- From b63222112da075886f34e5114cbbcb3e8bfdc0f4 Mon Sep 17 00:00:00 2001 From: Mario Santa Cruz Date: Thu, 3 Oct 2024 13:48:06 +0000 Subject: [PATCH 152/156] feat: support None as default argument for forcing area --- src/anemoi/graphs/nodes/builders/from_file.py | 14 ++++++-------- 1 file changed, 6 insertions(+), 8 deletions(-) diff --git a/src/anemoi/graphs/nodes/builders/from_file.py b/src/anemoi/graphs/nodes/builders/from_file.py index 0eeea9a..0f64f33 100644 --- a/src/anemoi/graphs/nodes/builders/from_file.py +++ b/src/anemoi/graphs/nodes/builders/from_file.py @@ -88,18 +88,16 @@ def __init__( forcing_area: list[float] | None = None, adjust: str = "all", ) -> None: + lam_config = {"dataset": lam_dataset, "thinning": thinning} + forcing_config = {"dataset": forcing_dataset} + + # Add area argument to crop the boundary forcing if forcing_area is not None: - forcing_area = tuple(forcing_area) + forcing_config["area"] = tuple(forcing_area) assert len(forcing_area) == 4, "The forcing area must be a list of 4 elements (north, west, south, east)." LOGGER.info("Forcing dataset is cropped to area: %s", forcing_area) - dataset_config = { - "cutout": [ - {"dataset": lam_dataset, "thinning": thinning}, - {"dataset": forcing_dataset, "area": forcing_area}, - ], - "adjust": adjust, - } + dataset_config = {"cutout": [lam_config, forcing_config], "adjust": adjust} super().__init__(dataset_config, name) self.n_cutout, self.n_other = self.dataset.grids From 5dedb0a43d8de07012e406885c32dc314e64ed34 Mon Sep 17 00:00:00 2001 From: Mario Santa Cruz Date: Tue, 15 Oct 2024 06:59:53 +0000 Subject: [PATCH 153/156] feat: refactor args of CutOutZarrDatasetNodes --- src/anemoi/graphs/nodes/builders/from_file.py | 42 ++++++------------- 1 file changed, 12 insertions(+), 30 deletions(-) diff --git a/src/anemoi/graphs/nodes/builders/from_file.py b/src/anemoi/graphs/nodes/builders/from_file.py index 0f64f33..87e710a 100644 --- a/src/anemoi/graphs/nodes/builders/from_file.py +++ b/src/anemoi/graphs/nodes/builders/from_file.py @@ -7,6 +7,7 @@ import torch from anemoi.datasets import open_dataset from anemoi.utils.config import DotDict +from omegaconf import OmegaConf from torch_geometric.data import HeteroData from anemoi.graphs.generate.masks import KNNAreaMaskBuilder @@ -56,16 +57,12 @@ class CutOutZarrDatasetNodes(ZarrDatasetNodes): Attributes ---------- - lam_dataset : str - The limited area dataset. - forcing_dataset : str - The forcing dataset. - thinning : int, optional - The thinning factor. Defaults to 1, which means no thinning. - forcing_area : list[float], optional - The area of the forcing dataset. Specify the longitude and - latitudes boundaries as (north, west, south, east). Defaults - to None, which means the forcing dataset is not cropped. + dataset : DotDict + The limited area dataset. Its schema is: + { + "cutout": [lam_dataset_config, forcing_dataset_config], + "adjust": ..., + } Methods ------- @@ -79,26 +76,11 @@ class CutOutZarrDatasetNodes(ZarrDatasetNodes): Update the graph with new nodes and attributes. """ - def __init__( - self, - name: str, - lam_dataset: str, - forcing_dataset: str, - thinning: int = 1, - forcing_area: list[float] | None = None, - adjust: str = "all", - ) -> None: - lam_config = {"dataset": lam_dataset, "thinning": thinning} - forcing_config = {"dataset": forcing_dataset} - - # Add area argument to crop the boundary forcing - if forcing_area is not None: - forcing_config["area"] = tuple(forcing_area) - assert len(forcing_area) == 4, "The forcing area must be a list of 4 elements (north, west, south, east)." - LOGGER.info("Forcing dataset is cropped to area: %s", forcing_area) - - dataset_config = {"cutout": [lam_config, forcing_config], "adjust": adjust} - super().__init__(dataset_config, name) + def __init__(self, dataset: DotDict, name: str) -> None: + assert ( + "cutout" in dataset + ), f"The 'cutout' key must be present in the dataset configuration for {self.__class__}." + super().__init__(OmegaConf.to_container(dataset), name) self.n_cutout, self.n_other = self.dataset.grids def register_attributes(self, graph: HeteroData, config: DotDict) -> None: From 820ca2f48c80e2467b1ecb6a47fba1539d71d496 Mon Sep 17 00:00:00 2001 From: Mario Santa Cruz Date: Wed, 16 Oct 2024 11:26:29 +0000 Subject: [PATCH 154/156] fix: tests --- tests/nodes/test_cutout_nodes.py | 17 +++++++---------- 1 file changed, 7 insertions(+), 10 deletions(-) diff --git a/tests/nodes/test_cutout_nodes.py b/tests/nodes/test_cutout_nodes.py index a18424d..e6484ae 100644 --- a/tests/nodes/test_cutout_nodes.py +++ b/tests/nodes/test_cutout_nodes.py @@ -1,18 +1,19 @@ import pytest import torch +from omegaconf import OmegaConf from torch_geometric.data import HeteroData from anemoi.graphs.nodes.attributes import AreaWeights from anemoi.graphs.nodes.attributes import UniformWeights from anemoi.graphs.nodes.builders import from_file +dataset_cfg = OmegaConf.create({"cutout": ["lam.zarr", "global.zarr"]}) + def test_init(mocker, mock_zarr_dataset_cutout): """Test CutOutZarrDatasetNodes initialization.""" mocker.patch.object(from_file, "open_dataset", return_value=mock_zarr_dataset_cutout) - node_builder = from_file.CutOutZarrDatasetNodes( - forcing_dataset="global.zarr", lam_dataset="lam.zarr", name="test_nodes" - ) + node_builder = from_file.CutOutZarrDatasetNodes(dataset_cfg, name="test_nodes") assert isinstance(node_builder, from_file.BaseNodeBuilder) assert isinstance(node_builder, from_file.CutOutZarrDatasetNodes) @@ -20,16 +21,14 @@ def test_init(mocker, mock_zarr_dataset_cutout): def test_fail_init(): """Test CutOutZarrDatasetNodes initialization with invalid resolution.""" - with pytest.raises(TypeError): + with pytest.raises(AssertionError): from_file.CutOutZarrDatasetNodes("global_dataset.zarr", name="test_nodes") def test_register_nodes(mocker, mock_zarr_dataset_cutout): """Test CutOutZarrDatasetNodes register correctly the nodes.""" mocker.patch.object(from_file, "open_dataset", return_value=mock_zarr_dataset_cutout) - node_builder = from_file.CutOutZarrDatasetNodes( - forcing_dataset="global.zarr", lam_dataset="lam.zarr", name="test_nodes" - ) + node_builder = from_file.CutOutZarrDatasetNodes(dataset_cfg, name="test_nodes") graph = HeteroData() graph = node_builder.register_nodes(graph) @@ -44,9 +43,7 @@ def test_register_nodes(mocker, mock_zarr_dataset_cutout): def test_register_attributes(mocker, mock_zarr_dataset_cutout, graph_with_nodes: HeteroData, attr_class): """Test CutOutZarrDatasetNodes register correctly the weights.""" mocker.patch.object(from_file, "open_dataset", return_value=mock_zarr_dataset_cutout) - node_builder = from_file.CutOutZarrDatasetNodes( - forcing_dataset="global.zarr", lam_dataset="lam.zarr", name="test_nodes" - ) + node_builder = from_file.CutOutZarrDatasetNodes(dataset_cfg, name="test_nodes") config = {"test_attr": {"_target_": f"anemoi.graphs.nodes.attributes.{attr_class.__name__}"}} graph = node_builder.register_attributes(graph_with_nodes, config) From d1dd70006e2413a0e5bd83569e290006862d49c3 Mon Sep 17 00:00:00 2001 From: Mario Santa Cruz Date: Wed, 16 Oct 2024 12:00:59 +0000 Subject: [PATCH 155/156] fix: update changelog --- CHANGELOG.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index c9aa6eb..81cfd31 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -23,7 +23,7 @@ Keep it human-readable, your future self will thank you! ### Changed - ci: small fixes and updates pre-commit, downsteam-ci (#49) -- feat: New argument 'forcing_area' in the CutOutZarDatasetNodes class. (#52) +- feat: Refactored CutOutZarDatasetNodes class. It now supports area and min_distance_km arguments. (#52) - Update CODEOWNERS ## [0.3.0 Anemoi-graphs, minor release](https://github.com/ecmwf/anemoi-graphs/compare/0.2.1...0.3.0) - 2024-09-03 From 5a0ee5e72eb5a2b714ed4cb7e674bc4c6afc38b6 Mon Sep 17 00:00:00 2001 From: Mario Santa Cruz Date: Thu, 17 Oct 2024 12:44:15 +0000 Subject: [PATCH 156/156] fix: type hint --- src/anemoi/graphs/nodes/builders/from_file.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/anemoi/graphs/nodes/builders/from_file.py b/src/anemoi/graphs/nodes/builders/from_file.py index 87e710a..8f4a9de 100644 --- a/src/anemoi/graphs/nodes/builders/from_file.py +++ b/src/anemoi/graphs/nodes/builders/from_file.py @@ -7,6 +7,7 @@ import torch from anemoi.datasets import open_dataset from anemoi.utils.config import DotDict +from omegaconf import DictConfig from omegaconf import OmegaConf from torch_geometric.data import HeteroData @@ -57,7 +58,7 @@ class CutOutZarrDatasetNodes(ZarrDatasetNodes): Attributes ---------- - dataset : DotDict + dataset : DictConfig The limited area dataset. Its schema is: { "cutout": [lam_dataset_config, forcing_dataset_config], @@ -76,7 +77,7 @@ class CutOutZarrDatasetNodes(ZarrDatasetNodes): Update the graph with new nodes and attributes. """ - def __init__(self, dataset: DotDict, name: str) -> None: + def __init__(self, dataset: DictConfig, name: str) -> None: assert ( "cutout" in dataset ), f"The 'cutout' key must be present in the dataset configuration for {self.__class__}."