Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[feature] Expand node attributes #68

Merged
merged 15 commits into from
Oct 29, 2024
17 changes: 10 additions & 7 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -12,13 +12,6 @@ Keep it human-readable, your future self will thank you!

### Added

- fix: bug in color when plotting isolated nodes

### Added

- Add anemoi-transform link to documentation

### Added
- ci: hpc-config, CODEOWNERS (#49)
- feat: New node builder class, CutOutZarrDatasetNodes, to create nodes from 2 datasets. (#30)
- feat: New class, KNNAreaMaskBuilder, to specify Area of Interest (AOI) based on a set of nodes. (#30)
Expand All @@ -27,10 +20,20 @@ Keep it human-readable, your future self will thank you!
- feat: New method update_graph(graph) in the GraphCreator class. (#60)
- feat: New class StretchedTriNodes to create a stretched mesh. (#51)
- feat: Expanded MultiScaleEdges to support multi-scale connections in stretched graphs. (#51)
- fix: bug in color when plotting isolated nodes (#63)
- Add anemoi-transform link to documentation (#59)
- Added `CutOutMask` class to create a mask for a cutout. (#68)
- Added `MissingZarrVariable` and `NotMissingZarrVariable` classes to create a mask for missing zarr variables. (#68)
- feat: Add CONTRIBUTORS.md file. (#72)

### Changed
- ci: small fixes and updates pre-commit, downstream-ci (#49)
- Update CODEOWNERS (#61)
- ci: extended python versions to include 3.11 and 3.12 (#66)
- Update copyright notice (#67)

### Removed
- Remove `CutOutZarrDatasetNodes` class. (#68)
- Update CODEOWNERS
- Fix pre-commit regex
- ci: extended python versions to include 3.11 and 3.12
Expand Down
4 changes: 2 additions & 2 deletions src/anemoi/graphs/generate/tri_icosahedron.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,8 +175,8 @@ def add_edges_to_nx_graph(

# Limit area of mesh points.
if area_mask_builder is not None:
aoi_mask = area_mask_builder.get_mask(r_vertices_rad)
valid_nodes = np.where(aoi_mask)[0]
area_mask = area_mask_builder.get_mask(r_vertices_rad)
valid_nodes = np.where(area_mask)[0]
else:
valid_nodes = None

Expand Down
2 changes: 0 additions & 2 deletions src/anemoi/graphs/nodes/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
from .builders.from_file import CutOutZarrDatasetNodes
from .builders.from_file import LimitedAreaNPZFileNodes
from .builders.from_file import NPZFileNodes
from .builders.from_file import ZarrDatasetNodes
Expand All @@ -17,7 +16,6 @@
"HexNodes",
"HEALPixNodes",
"LimitedAreaHEALPixNodes",
"CutOutZarrDatasetNodes",
"LimitedAreaNPZFileNodes",
"LimitedAreaTriNodes",
"LimitedAreaHexNodes",
Expand Down
129 changes: 106 additions & 23 deletions src/anemoi/graphs/nodes/attributes.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,11 @@
import logging
from abc import ABC
from abc import abstractmethod
from typing import Type

import numpy as np
import torch
from anemoi.datasets import open_dataset
from scipy.spatial import SphericalVoronoi
from torch_geometric.data import HeteroData
from torch_geometric.data.storage import NodeStorage
Expand All @@ -25,14 +27,15 @@
LOGGER = logging.getLogger(__name__)


class BaseWeights(ABC, NormaliserMixin):
class BaseNodeAttribute(ABC, NormaliserMixin):
"""Base class for the weights of the nodes."""

def __init__(self, norm: str | None = None) -> None:
def __init__(self, norm: str | None = None, dtype: str = "float32") -> None:
self.norm = norm
self.dtype = dtype

@abstractmethod
def get_raw_values(self, nodes: NodeStorage, *args, **kwargs): ...
def get_raw_values(self, nodes: NodeStorage, **kwargs) -> np.ndarray: ...

def post_process(self, values: np.ndarray) -> torch.Tensor:
"""Post-process the values."""
Expand All @@ -41,33 +44,31 @@ def post_process(self, values: np.ndarray) -> torch.Tensor:

norm_values = self.normalise(values)

return torch.tensor(norm_values, dtype=torch.float32)
return torch.tensor(norm_values.astype(self.dtype))

def compute(self, graph: HeteroData, nodes_name: str, *args, **kwargs) -> torch.Tensor:
"""Get the node weights.
def compute(self, graph: HeteroData, nodes_name: str, **kwargs) -> torch.Tensor:
"""Get the nodes attribute.

Parameters
----------
graph : HeteroData
Graph.
nodes_name : str
Name of the nodes.
args : tuple
Additional arguments.
kwargs : dict
Additional keyword arguments.

Returns
-------
torch.Tensor
Weights associated to the nodes.
Attributes associated to the nodes.
"""
nodes = graph[nodes_name]
weights = self.get_raw_values(nodes, *args, **kwargs)
return self.post_process(weights)
attributes = self.get_raw_values(nodes, **kwargs)
return self.post_process(attributes)


class UniformWeights(BaseWeights):
class UniformWeights(BaseNodeAttribute):
"""Implements a uniform weight for the nodes.

Methods
Expand All @@ -76,27 +77,25 @@ class UniformWeights(BaseWeights):
Compute the area attributes for each node.
"""

def get_raw_values(self, nodes: NodeStorage, *args, **kwargs) -> np.ndarray:
def get_raw_values(self, nodes: NodeStorage, **kwargs) -> np.ndarray:
"""Compute the weights.

Parameters
----------
nodes : NodeStorage
Nodes of the graph.
args : tuple
Additional arguments.
kwargs : dict
Additional keyword arguments.

Returns
-------
np.ndarray
Weights.
Attributes.
"""
return np.ones(nodes.num_nodes)


class AreaWeights(BaseWeights):
class AreaWeights(BaseNodeAttribute):
"""Implements the area of the nodes as the weights.

Attributes
Expand All @@ -114,12 +113,18 @@ class AreaWeights(BaseWeights):
Compute the area attributes for each node.
"""

def __init__(self, norm: str | None = None, radius: float = 1.0, centre: np.ndarray = np.array([0, 0, 0])) -> None:
super().__init__(norm)
def __init__(
self,
norm: str | None = None,
radius: float = 1.0,
centre: np.ndarray = np.array([0, 0, 0]),
dtype: str = "float32",
) -> None:
super().__init__(norm, dtype)
self.radius = radius
self.centre = centre

def get_raw_values(self, nodes: NodeStorage, *args, **kwargs) -> np.ndarray:
def get_raw_values(self, nodes: NodeStorage, **kwargs) -> np.ndarray:
"""Compute the area associated to each node.

It uses Voronoi diagrams to compute the area of each node.
Expand All @@ -128,15 +133,13 @@ def get_raw_values(self, nodes: NodeStorage, *args, **kwargs) -> np.ndarray:
----------
nodes : NodeStorage
Nodes of the graph.
args : tuple
Additional arguments.
kwargs : dict
Additional keyword arguments.

Returns
-------
np.ndarray
Weights.
Attributes.
"""
latitudes, longitudes = nodes.x[:, 0], nodes.x[:, 1]
points = latlon_rad_to_cartesian((np.asarray(latitudes), np.asarray(longitudes)))
Expand All @@ -148,3 +151,83 @@ def get_raw_values(self, nodes: NodeStorage, *args, **kwargs) -> np.ndarray:
np.array(area_weights).sum(),
)
return area_weights


class BooleanBaseNodeAttribute(BaseNodeAttribute, ABC):
    """Abstract parent of node attributes whose values are boolean masks.

    The parent initialiser is always invoked with ``norm=None`` and
    ``dtype="bool"``; subclasses take no normalisation or dtype arguments.
    """

    def __init__(self) -> None:
        # Fixed settings shared by every boolean attribute.
        boolean_settings = {"norm": None, "dtype": "bool"}
        super().__init__(**boolean_settings)


class NonmissingZarrVariable(BooleanBaseNodeAttribute):
    """Boolean mask flagging where a Zarr dataset variable is not NaN.

    The dataset referenced by the nodes' ``"_dataset"`` entry is opened with
    only ``variable`` selected. Its element at index 0 is squeezed and every
    non-NaN entry is marked ``True``.

    Attributes
    ----------
    variable : str
        Variable to read from the Zarr dataset.

    Methods
    -------
    get_raw_values(self, nodes, **kwargs)
        Compute the non-missing mask for the given nodes.
    """

    def __init__(self, variable: str) -> None:
        super().__init__()
        self.variable = variable

    def get_raw_values(self, nodes: NodeStorage, **kwargs) -> np.ndarray:
        """Build the mask of non-NaN values of ``self.variable``.

        Parameters
        ----------
        nodes : NodeStorage
            Nodes of the graph. Must provide the ``"node_type"`` and
            ``"_dataset"`` entries.
        kwargs : dict
            Additional keyword arguments (unused).

        Returns
        -------
        np.ndarray
            ``True`` where the selected variable is not NaN.

        Raises
        ------
        AssertionError
            If ``nodes["node_type"]`` is not ``"ZarrDatasetNodes"``.
        """
        builder_type = nodes["node_type"]
        assert (
            builder_type == "ZarrDatasetNodes"
        ), f"{self.__class__.__name__} can only be used with ZarrDatasetNodes."
        dataset = open_dataset(nodes["_dataset"], select=self.variable)
        # Only the element at index 0 of the dataset is inspected.
        first_sample = dataset[0].squeeze()
        return np.logical_not(np.isnan(first_sample))


class CutOutMask(BooleanBaseNodeAttribute):
    """Boolean mask separating the two grids of a cutout dataset.

    The first ``grids`` entry of the opened dataset is marked ``True`` and
    the second one ``False``.
    """

    def get_raw_values(self, nodes: NodeStorage, **kwargs) -> np.ndarray:
        """Build the cutout mask.

        Parameters
        ----------
        nodes : NodeStorage
            Nodes of the graph. ``nodes["_dataset"]`` must be a dictionary
            containing a ``"cutout"`` key.
        kwargs : dict
            Additional keyword arguments (unused).

        Returns
        -------
        np.ndarray
            Boolean array with ``num_lam`` leading ``True`` values followed
            by ``num_other`` ``False`` values.

        Raises
        ------
        AssertionError
            If the dataset entry is not a dictionary or has no ``"cutout"`` key.
        """
        dataset_config = nodes["_dataset"]
        assert isinstance(dataset_config, dict), "The 'dataset' attribute must be a dictionary."
        assert "cutout" in dataset_config, "The 'dataset' attribute must contain a 'cutout' key."
        num_lam, num_other = open_dataset(dataset_config).grids
        inside = np.ones(num_lam, dtype=bool)
        outside = np.zeros(num_other, dtype=bool)
        return np.concatenate([inside, outside])


class BooleanOperation(BooleanBaseNodeAttribute, ABC):
    """Base class for boolean operations combining several masks.

    Attributes
    ----------
    masks : list[str | BooleanBaseNodeAttribute]
        Masks to combine. Each item is either the name of an attribute
        already stored in the nodes, or an attribute *instance* whose
        ``get_raw_values`` is called to compute the mask.
    """

    def __init__(self, masks: list[str | BooleanBaseNodeAttribute]) -> None:
        super().__init__()
        self.masks = masks

    @staticmethod
    def get_mask_values(mask: str | BaseNodeAttribute, nodes: NodeStorage, **kwargs) -> np.ndarray:
        """Resolve a single mask to its boolean values.

        Parameters
        ----------
        mask : str | BaseNodeAttribute
            Name of an attribute stored in ``nodes``, or an attribute
            instance providing ``get_raw_values``.
        nodes : NodeStorage
            Nodes of the graph.
        kwargs : dict
            Additional keyword arguments forwarded to ``get_raw_values``.

        Returns
        -------
        np.ndarray
            The mask values. A named attribute is returned exactly as it is
            stored in ``nodes`` (no conversion or reshaping is applied).

        Raises
        ------
        AssertionError
            If the named attribute is not of boolean dtype.
        """
        if isinstance(mask, str):
            attributes = nodes[mask]
            # Compare on the dtype's name: ``torch.bool == "bool"`` is False,
            # so a direct comparison with the string would reject boolean
            # torch tensors; NumPy's bool dtype prints as "bool".
            dtype_name = str(attributes.dtype)
            assert dtype_name in (
                "bool",
                "torch.bool",
            ), f"The mask attribute '{mask}' must be a boolean but is {attributes.dtype}."
            return attributes

        return mask.get_raw_values(nodes, **kwargs)


class BooleanAndMask(BooleanOperation):
    """Element-wise logical AND of all the configured masks."""

    def get_raw_values(self, nodes: NodeStorage, **kwargs) -> np.ndarray:
        """Resolve every mask in ``self.masks`` and AND them together.

        Parameters
        ----------
        nodes : NodeStorage
            Nodes of the graph.
        kwargs : dict
            Additional keyword arguments forwarded to each mask.

        Returns
        -------
        np.ndarray
            Result of ``np.logical_and.reduce`` over the resolved masks.
        """
        resolve = BooleanOperation.get_mask_values
        mask_values = [resolve(mask, nodes, **kwargs) for mask in self.masks]
        return np.logical_and.reduce(mask_values)


class BooleanOrMask(BooleanOperation):
    """Element-wise logical OR of all the configured masks."""

    def get_raw_values(self, nodes: NodeStorage, **kwargs) -> np.ndarray:
        """Resolve every mask in ``self.masks`` and OR them together.

        Parameters
        ----------
        nodes : NodeStorage
            Nodes of the graph.
        kwargs : dict
            Additional keyword arguments forwarded to each mask.

        Returns
        -------
        np.ndarray
            Result of ``np.logical_or.reduce`` over the resolved masks.
        """
        resolve = BooleanOperation.get_mask_values
        mask_values = [resolve(mask, nodes, **kwargs) for mask in self.masks]
        return np.logical_or.reduce(mask_values)
7 changes: 6 additions & 1 deletion src/anemoi/graphs/nodes/builders/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,10 +28,12 @@ class BaseNodeBuilder(ABC):
----------
name : str
name of the nodes, key for the nodes in the HeteroData graph object.
aoi_mask_builder : KNNAreaMaskBuilder
area_mask_builder : KNNAreaMaskBuilder
The area of interest mask builder, if any. Defaults to None.
"""

hidden_attributes: set[str] = set()

def __init__(self, name: str) -> None:
self.name = name
self.area_mask_builder = None
Expand Down Expand Up @@ -68,6 +70,9 @@ def register_attributes(self, graph: HeteroData, config: DotDict | None = None)
HeteroData
The graph with the registered attributes.
"""
for hidden_attr in self.hidden_attributes:
graph[self.name][f"_{hidden_attr}"] = getattr(self, hidden_attr)

for attr_name, attr_config in config.items():
graph[self.name][attr_name] = instantiate(attr_config).compute(graph, self.name)

Expand Down
34 changes: 8 additions & 26 deletions src/anemoi/graphs/nodes/builders/from_file.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,8 @@
import numpy as np
import torch
from anemoi.datasets import open_dataset
from anemoi.utils.config import DotDict
from omegaconf import DictConfig
from omegaconf import OmegaConf
from torch_geometric.data import HeteroData

from anemoi.graphs.generate.masks import KNNAreaMaskBuilder
Expand All @@ -29,7 +30,7 @@ class ZarrDatasetNodes(BaseNodeBuilder):

Attributes
----------
dataset : zarr.core.Array
dataset : str | DictConfig
The dataset.

Methods
Expand All @@ -44,10 +45,11 @@ class ZarrDatasetNodes(BaseNodeBuilder):
Update the graph with new nodes and attributes.
"""

def __init__(self, dataset: DotDict, name: str) -> None:
def __init__(self, dataset: DictConfig, name: str) -> None:
LOGGER.info("Reading the dataset from %s.", dataset)
self.dataset = open_dataset(dataset)
self.dataset = dataset if isinstance(dataset, str) else OmegaConf.to_container(dataset)
super().__init__(name)
self.hidden_attributes = BaseNodeBuilder.hidden_attributes | {"dataset"}

def get_coordinates(self) -> torch.Tensor:
"""Get the coordinates of the nodes.
Expand All @@ -57,28 +59,8 @@ def get_coordinates(self) -> torch.Tensor:
torch.Tensor of shape (num_nodes, 2)
A 2D tensor with the coordinates, in radians.
"""
return self.reshape_coords(self.dataset.latitudes, self.dataset.longitudes)


class CutOutZarrDatasetNodes(ZarrDatasetNodes):
"""Nodes from Zarr dataset."""

def __init__(
self, name: str, lam_dataset: str, forcing_dataset: str, thinning: int = 1, adjust: str = "all"
) -> None:
dataset_config = {
"cutout": [{"dataset": lam_dataset, "thinning": thinning}, {"dataset": forcing_dataset}],
"adjust": adjust,
}
super().__init__(dataset_config, name)
self.n_cutout, self.n_other = self.dataset.grids

def register_attributes(self, graph: HeteroData, config: DotDict) -> None:
# this is a mask to cutout the LAM area
graph[self.name]["cutout"] = torch.tensor([True] * self.n_cutout + [False] * self.n_other, dtype=bool).reshape(
(-1, 1)
)
return super().register_attributes(graph, config)
dataset = open_dataset(self.dataset)
return self.reshape_coords(dataset.latitudes, dataset.longitudes)


class NPZFileNodes(BaseNodeBuilder):
Expand Down
Loading
Loading