Skip to content

Commit

Permalink
[feature] Expand node attributes (#68)
Browse files Browse the repository at this point in the history
* feat: set hidden attributes as class attribute

* feat: remove CutOutZarrDatasetNodes

* feat: update test

* fix: tests

* featL update changelog.md

* feat: expand node attributes (and)

* fix: rename n -> num

* fix: update Nonmissing

* fix: remove ds

* fix: tests
  • Loading branch information
JPXKQX authored Oct 29, 2024
1 parent af6a840 commit 009744e
Show file tree
Hide file tree
Showing 9 changed files with 157 additions and 92 deletions.
17 changes: 10 additions & 7 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -12,13 +12,6 @@ Keep it human-readable, your future self will thank you!

### Added

- fix: bug in color when plotting isolated nodes

### Added

- Add anemoi-transform link to documentation

### Added
- ci: hpc-config, CODEOWNERS (#49)
- feat: New node builder class, CutOutZarrDatasetNodes, to create nodes from 2 datasets. (#30)
- feat: New class, KNNAreaMaskBuilder, to specify Area of Interest (AOI) based on a set of nodes. (#30)
Expand All @@ -27,10 +20,20 @@ Keep it human-readable, your future self will thank you!
- feat: New method update_graph(graph) in the GraphCreator class. (#60)
- feat: New class StretchedTriNodes to create a stretched mesh. (#51)
- feat: Expanded MultiScaleEdges to support multi-scale connections in stretched graphs. (#51)
- fix: bug in color when plotting isolated nodes (#63)
- Add anemoi-transform link to documentation (#59)
- Added `CutOutMask` class to create a mask for a cutout. (#68)
- Added `MissingZarrVariable` and `NotMissingZarrVariable` classes to create a mask for missing zarr variables. (#68)
- feat: Add CONTRIBUTORS.md file. (#72)

### Changed
- ci: small fixes and updates pre-commit, downsteam-ci (#49)
- Update CODEOWNERS (#61)
- ci: extened python versions to include 3.11 and 3.12 (#66)
- Update copyright notice (#67)

### Removed
- Remove `CutOutZarrDatasetNodes` class. (#68)
- Update CODEOWNERS
- Fix pre-commit regex
- ci: extened python versions to include 3.11 and 3.12
Expand Down
4 changes: 2 additions & 2 deletions src/anemoi/graphs/generate/tri_icosahedron.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,8 +175,8 @@ def add_edges_to_nx_graph(

# Limit area of mesh points.
if area_mask_builder is not None:
aoi_mask = area_mask_builder.get_mask(r_vertices_rad)
valid_nodes = np.where(aoi_mask)[0]
area_mask = area_mask_builder.get_mask(r_vertices_rad)
valid_nodes = np.where(area_mask)[0]
else:
valid_nodes = None

Expand Down
2 changes: 0 additions & 2 deletions src/anemoi/graphs/nodes/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
from .builders.from_file import CutOutZarrDatasetNodes
from .builders.from_file import LimitedAreaNPZFileNodes
from .builders.from_file import NPZFileNodes
from .builders.from_file import ZarrDatasetNodes
Expand All @@ -17,7 +16,6 @@
"HexNodes",
"HEALPixNodes",
"LimitedAreaHEALPixNodes",
"CutOutZarrDatasetNodes",
"LimitedAreaNPZFileNodes",
"LimitedAreaTriNodes",
"LimitedAreaHexNodes",
Expand Down
129 changes: 106 additions & 23 deletions src/anemoi/graphs/nodes/attributes.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,11 @@
import logging
from abc import ABC
from abc import abstractmethod
from typing import Type

import numpy as np
import torch
from anemoi.datasets import open_dataset
from scipy.spatial import SphericalVoronoi
from torch_geometric.data import HeteroData
from torch_geometric.data.storage import NodeStorage
Expand All @@ -25,14 +27,15 @@
LOGGER = logging.getLogger(__name__)


class BaseWeights(ABC, NormaliserMixin):
class BaseNodeAttribute(ABC, NormaliserMixin):
"""Base class for the weights of the nodes."""

def __init__(self, norm: str | None = None) -> None:
def __init__(self, norm: str | None = None, dtype: str = "float32") -> None:
self.norm = norm
self.dtype = dtype

@abstractmethod
def get_raw_values(self, nodes: NodeStorage, *args, **kwargs): ...
def get_raw_values(self, nodes: NodeStorage, **kwargs) -> np.ndarray: ...

def post_process(self, values: np.ndarray) -> torch.Tensor:
"""Post-process the values."""
Expand All @@ -41,33 +44,31 @@ def post_process(self, values: np.ndarray) -> torch.Tensor:

norm_values = self.normalise(values)

return torch.tensor(norm_values, dtype=torch.float32)
return torch.tensor(norm_values.astype(self.dtype))

def compute(self, graph: HeteroData, nodes_name: str, *args, **kwargs) -> torch.Tensor:
"""Get the node weights.
def compute(self, graph: HeteroData, nodes_name: str, **kwargs) -> torch.Tensor:
"""Get the nodes attribute.
Parameters
----------
graph : HeteroData
Graph.
nodes_name : str
Name of the nodes.
args : tuple
Additional arguments.
kwargs : dict
Additional keyword arguments.
Returns
-------
torch.Tensor
Weights associated to the nodes.
Attributes associated to the nodes.
"""
nodes = graph[nodes_name]
weights = self.get_raw_values(nodes, *args, **kwargs)
return self.post_process(weights)
attributes = self.get_raw_values(nodes, **kwargs)
return self.post_process(attributes)


class UniformWeights(BaseWeights):
class UniformWeights(BaseNodeAttribute):
"""Implements a uniform weight for the nodes.
Methods
Expand All @@ -76,27 +77,25 @@ class UniformWeights(BaseWeights):
Compute the area attributes for each node.
"""

def get_raw_values(self, nodes: NodeStorage, *args, **kwargs) -> np.ndarray:
def get_raw_values(self, nodes: NodeStorage, **kwargs) -> np.ndarray:
"""Compute the weights.
Parameters
----------
nodes : NodeStorage
Nodes of the graph.
args : tuple
Additional arguments.
kwargs : dict
Additional keyword arguments.
Returns
-------
np.ndarray
Weights.
Attributes.
"""
return np.ones(nodes.num_nodes)


class AreaWeights(BaseWeights):
class AreaWeights(BaseNodeAttribute):
"""Implements the area of the nodes as the weights.
Attributes
Expand All @@ -114,12 +113,18 @@ class AreaWeights(BaseWeights):
Compute the area attributes for each node.
"""

def __init__(self, norm: str | None = None, radius: float = 1.0, centre: np.ndarray = np.array([0, 0, 0])) -> None:
super().__init__(norm)
def __init__(
self,
norm: str | None = None,
radius: float = 1.0,
centre: np.ndarray = np.array([0, 0, 0]),
dtype: str = "float32",
) -> None:
super().__init__(norm, dtype)
self.radius = radius
self.centre = centre

def get_raw_values(self, nodes: NodeStorage, *args, **kwargs) -> np.ndarray:
def get_raw_values(self, nodes: NodeStorage, **kwargs) -> np.ndarray:
"""Compute the area associated to each node.
It uses Voronoi diagrams to compute the area of each node.
Expand All @@ -128,15 +133,13 @@ def get_raw_values(self, nodes: NodeStorage, *args, **kwargs) -> np.ndarray:
----------
nodes : NodeStorage
Nodes of the graph.
args : tuple
Additional arguments.
kwargs : dict
Additional keyword arguments.
Returns
-------
np.ndarray
Weights.
Attributes.
"""
latitudes, longitudes = nodes.x[:, 0], nodes.x[:, 1]
points = latlon_rad_to_cartesian((np.asarray(latitudes), np.asarray(longitudes)))
Expand All @@ -148,3 +151,83 @@ def get_raw_values(self, nodes: NodeStorage, *args, **kwargs) -> np.ndarray:
np.array(area_weights).sum(),
)
return area_weights


class BooleanBaseNodeAttribute(BaseNodeAttribute, ABC):
"""Base class for boolean node attributes."""

def __init__(self) -> None:
super().__init__(norm=None, dtype="bool")


class NonmissingZarrVariable(BooleanBaseNodeAttribute):
"""Mask of valid (not missing) values of a Zarr dataset variable.
It reads a variable from a Zarr dataset and returns a boolean mask of nonmissing values in the first timestep.
Attributes
----------
variable : str
Variable to read from the Zarr dataset.
norm : str
Normalization of the weights.
Methods
-------
compute(self, graph, nodes_name)
Compute the attribute for each node.
"""

def __init__(self, variable: str) -> None:
super().__init__()
self.variable = variable

def get_raw_values(self, nodes: NodeStorage, **kwargs) -> np.ndarray:
assert (
nodes["node_type"] == "ZarrDatasetNodes"
), f"{self.__class__.__name__} can only be used with ZarrDatasetNodes."
ds = open_dataset(nodes["_dataset"], select=self.variable)[0].squeeze()
return ~np.isnan(ds)


class CutOutMask(BooleanBaseNodeAttribute):
"""Cut out mask."""

def get_raw_values(self, nodes: NodeStorage, **kwargs) -> np.ndarray:
assert isinstance(nodes["_dataset"], dict), "The 'dataset' attribute must be a dictionary."
assert "cutout" in nodes["_dataset"], "The 'dataset' attribute must contain a 'cutout' key."
num_lam, num_other = open_dataset(nodes["_dataset"]).grids
return np.array([True] * num_lam + [False] * num_other, dtype=bool)


class BooleanOperation(BooleanBaseNodeAttribute, ABC):
"""Base class for boolean operations."""

def __init__(self, masks: list[str | Type[BooleanBaseNodeAttribute]]) -> None:
super().__init__()
self.masks = masks

@staticmethod
def get_mask_values(mask: str | Type[BaseNodeAttribute], nodes: NodeStorage, **kwargs) -> np.array:
if isinstance(mask, str):
attributes = nodes[mask]
assert (
attributes.dtype == "bool"
), f"The mask attribute '{mask}' must be a boolean but is {attributes.dtype}."
return attributes

return mask.get_raw_values(nodes, **kwargs)


class BooleanAndMask(BooleanOperation):
"""Boolean AND mask."""

def get_raw_values(self, nodes: NodeStorage, **kwargs) -> np.ndarray:
return np.logical_and.reduce([BooleanOperation.get_mask_values(mask, nodes, **kwargs) for mask in self.masks])


class BooleanOrMask(BooleanOperation):
"""Boolean OR mask."""

def get_raw_values(self, nodes: NodeStorage, **kwargs) -> np.ndarray:
return np.logical_or.reduce([BooleanOperation.get_mask_values(mask, nodes, **kwargs) for mask in self.masks])
7 changes: 6 additions & 1 deletion src/anemoi/graphs/nodes/builders/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,10 +28,12 @@ class BaseNodeBuilder(ABC):
----------
name : str
name of the nodes, key for the nodes in the HeteroData graph object.
aoi_mask_builder : KNNAreaMaskBuilder
area_mask_builder : KNNAreaMaskBuilder
The area of interest mask builder, if any. Defaults to None.
"""

hidden_attributes: set[str] = set()

def __init__(self, name: str) -> None:
self.name = name
self.area_mask_builder = None
Expand Down Expand Up @@ -68,6 +70,9 @@ def register_attributes(self, graph: HeteroData, config: DotDict | None = None)
HeteroData
The graph with the registered attributes.
"""
for hidden_attr in self.hidden_attributes:
graph[self.name][f"_{hidden_attr}"] = getattr(self, hidden_attr)

for attr_name, attr_config in config.items():
graph[self.name][attr_name] = instantiate(attr_config).compute(graph, self.name)

Expand Down
34 changes: 8 additions & 26 deletions src/anemoi/graphs/nodes/builders/from_file.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,8 @@
import numpy as np
import torch
from anemoi.datasets import open_dataset
from anemoi.utils.config import DotDict
from omegaconf import DictConfig
from omegaconf import OmegaConf
from torch_geometric.data import HeteroData

from anemoi.graphs.generate.masks import KNNAreaMaskBuilder
Expand All @@ -29,7 +30,7 @@ class ZarrDatasetNodes(BaseNodeBuilder):
Attributes
----------
dataset : zarr.core.Array
dataset : str | DictConfig
The dataset.
Methods
Expand All @@ -44,10 +45,11 @@ class ZarrDatasetNodes(BaseNodeBuilder):
Update the graph with new nodes and attributes.
"""

def __init__(self, dataset: DotDict, name: str) -> None:
def __init__(self, dataset: DictConfig, name: str) -> None:
LOGGER.info("Reading the dataset from %s.", dataset)
self.dataset = open_dataset(dataset)
self.dataset = dataset if isinstance(dataset, str) else OmegaConf.to_container(dataset)
super().__init__(name)
self.hidden_attributes = BaseNodeBuilder.hidden_attributes | {"dataset"}

def get_coordinates(self) -> torch.Tensor:
"""Get the coordinates of the nodes.
Expand All @@ -57,28 +59,8 @@ def get_coordinates(self) -> torch.Tensor:
torch.Tensor of shape (num_nodes, 2)
A 2D tensor with the coordinates, in radians.
"""
return self.reshape_coords(self.dataset.latitudes, self.dataset.longitudes)


class CutOutZarrDatasetNodes(ZarrDatasetNodes):
"""Nodes from Zarr dataset."""

def __init__(
self, name: str, lam_dataset: str, forcing_dataset: str, thinning: int = 1, adjust: str = "all"
) -> None:
dataset_config = {
"cutout": [{"dataset": lam_dataset, "thinning": thinning}, {"dataset": forcing_dataset}],
"adjust": adjust,
}
super().__init__(dataset_config, name)
self.n_cutout, self.n_other = self.dataset.grids

def register_attributes(self, graph: HeteroData, config: DotDict) -> None:
# this is a mask to cutout the LAM area
graph[self.name]["cutout"] = torch.tensor([True] * self.n_cutout + [False] * self.n_other, dtype=bool).reshape(
(-1, 1)
)
return super().register_attributes(graph, config)
dataset = open_dataset(self.dataset)
return self.reshape_coords(dataset.latitudes, dataset.longitudes)


class NPZFileNodes(BaseNodeBuilder):
Expand Down
Loading

0 comments on commit 009744e

Please sign in to comment.