Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feature: support icon graphs #53

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
37 commits
Select commit Hold shift + click to select a range
100c251
feat: topology-based encoder/processor/decoder graphs derived from an…
fprill Sep 26, 2024
9848cf0
docs: attempt to set a correct license header.
fprill Sep 26, 2024
69ae2fa
doc: changed the docstring style to the Numpy-style instead of Google…
fprill Sep 30, 2024
d858d0e
refactor: changed variable name `verts` into `vertices`.
fprill Sep 30, 2024
5e21ff5
fix: changed float division to double slash operator.
fprill Sep 30, 2024
3d662d1
refactor: removed unnecessary `else` branch.
fprill Sep 30, 2024
cfd82a9
refactor: more descriptive variable name in for loop.
fprill Sep 30, 2024
48230ec
refactor: renamed variable (ignoring the minus sign of `phi`).
fprill Sep 30, 2024
6ad883f
refactor: changed name of temporary variable.
fprill Sep 30, 2024
1862a66
refactor: remove unnecessary `else´branch.
fprill Sep 30, 2024
6f6e501
refactor: added type annotation for completeness.
fprill Sep 30, 2024
9bc7ae0
refactor: remove default in function argument.
fprill Sep 30, 2024
0fae14d
refactor: change argument name to a more understandable name.
fprill Sep 30, 2024
a44fb01
refactor: remove redundant code.
fprill Sep 30, 2024
25630cc
refactor: more Pythonic if-else statement.
fprill Sep 30, 2024
612afe5
refactor: more Pythonic if-else statement.
fprill Sep 30, 2024
368427c
refactor: use more appropriate `LOGGER.debug` instead of verbosity flag.
fprill Sep 30, 2024
5c76e10
refactor: more appropriate variable name.
fprill Sep 30, 2024
f6d63f0
refactor: more appropriate variable name.
fprill Sep 30, 2024
8efd2c6
refactor: unified the three ICON Edgebuilders and the two Nodebuilders.
fprill Sep 30, 2024
1961150
refactor: more verbose but also clearer names for variables in mesh c…
fprill Sep 30, 2024
30dee38
remove: removed obsolete function `set_constant_edge_id`.
fprill Sep 30, 2024
92cb023
refactor: replaced the sequential ID counter by a UUID.
fprill Sep 30, 2024
9937a61
revert change of copyright notice
MeraX Oct 16, 2024
ca67c7d
Merge branch 'develop' into feature/support-icon-graphs
fprill Nov 6, 2024
75057c3
[fix] add encoder & processor edges to the __all__ variable
fprill Nov 6, 2024
dcff399
[refactor] move auxiliary functions to utils.py in graphs/generate/
fprill Nov 6, 2024
9b19fad
[doc] added empty torso for class documentation.
fprill Nov 6, 2024
611e264
[fix] fixed interfaces (masks).
fprill Nov 8, 2024
797a6a2
[refactor] remove edge attribute calculation from this PR.
fprill Nov 8, 2024
37cb87a
[chore] adjust copyright notice.
fprill Nov 8, 2024
90a1721
[doc] elaborate on icon mesh classes in rst file.
fprill Nov 8, 2024
b45e21e
Add Icon tests
MeraX Nov 11, 2024
86fc909
Merge remote-tracking branch 'github/develop' into feature/support-ic…
MeraX Nov 11, 2024
63ea1b0
update Change log
MeraX Nov 11, 2024
4aaa9ff
Add __future__ annotations
MeraX Nov 11, 2024
01a94f0
fix change log
MeraX Nov 12, 2024
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ Please add your functional changes to the appropriate section in the PR.
Keep it human-readable, your future self will thank you!

## [Unreleased](https://github.com/ecmwf/anemoi-graphs/compare/0.4.0...HEAD)
- feat: Define node sets and edges based on an ICON icosahedral mesh (#53)

## [0.4.0 - LAM and stretched graphs](https://github.com/ecmwf/anemoi-graphs/compare/0.3.0...0.4.0) - 2024-11-08

Expand Down
64 changes: 64 additions & 0 deletions docs/graphs/node_coordinates/icon_mesh.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
####################################
Triangular Mesh with ICON Topology
####################################

The classes `ICONMultimeshNodes` and `ICONCellGridNodes` define node
sets based on an ICON icosahedral mesh:

- class `ICONCellGridNodes`: data grid, representing cell circumcenters
- class `ICONMultimeshNodes`: hidden mesh, representing the vertices of
a grid hierarchy

Both classes, together with the corresponding edge builders

- class `ICONTopologicalProcessorEdges`
- class `ICONTopologicalEncoderEdges`
- class `ICONTopologicalDecoderEdges`

are based on the mesh hierarchy that is reconstructed from an ICON mesh
file in NetCDF format, making use of the `refinement_level_v` and
`refinement_level_c` property contained therein.

- `refinement_level_v[vertex] = 0,1,2, ...`,
where 0 denotes the vertices of the base grid, ie. the icosahedron
including the step of root subdivision RXXB00.

- `refinement_level_c[cell]`: cell refinement level index such that
value 0 denotes the cells of the base grid, ie. the icosahedron
including the step of root subdivision RXXB00.

To avoid multiple runs of the reconstruction algorithm, a separate
`ICONNodes` instance is created and used by the builders, see the
following YAML example:

.. code:: yaml

nodes:
# ICON mesh
icon_mesh:
node_builder:
_target_: anemoi.graphs.nodes.ICONNodes
name: "icon_grid_0026_R03B07_G"
grid_filename: "icon_grid_0026_R03B07_G.nc"
max_level_multimesh: 3
max_level_dataset: 3
# Data nodes
data:
node_builder:
_target_: anemoi.graphs.nodes.ICONCellGridNodes
icon_mesh: "icon_mesh"
attributes: ${graph.attributes.nodes}
# Hidden nodes
hidden:
node_builder:
_target_: anemoi.graphs.nodes.ICONMultimeshNodes
icon_mesh: "icon_mesh"

edges:
# Processor configuration
- source_name: ${graph.hidden}
target_name: ${graph.hidden}
edge_builder:
_target_: anemoi.graphs.edges.ICONTopologicalProcessorEdges
icon_mesh: "icon_mesh"
attributes: ${graph.attributes.edges}
2 changes: 2 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -45,11 +45,13 @@ dependencies = [
"healpy>=1.17",
"hydra-core>=1.3",
"matplotlib>=3.4",
"netcdf4",
"networkx>=3.1",
"plotly>=5.19",
"torch>=2.2",
"torch-geometric>=2.3.1,<2.5",
"trimesh>=4.1",
"typeguard",
]

optional-dependencies.all = [ ]
Expand Down
12 changes: 11 additions & 1 deletion src/anemoi/graphs/edges/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,17 @@
# nor does it submit to any jurisdiction.

from .builder import CutOffEdges
from .builder import ICONTopologicalDecoderEdges
from .builder import ICONTopologicalEncoderEdges
from .builder import ICONTopologicalProcessorEdges
from .builder import KNNEdges
from .builder import MultiScaleEdges

fprill marked this conversation as resolved.
Show resolved Hide resolved
__all__ = ["KNNEdges", "CutOffEdges", "MultiScaleEdges"]
__all__ = [
"KNNEdges",
"CutOffEdges",
"MultiScaleEdges",
"ICONTopologicalProcessorEdges",
"ICONTopologicalEncoderEdges",
"ICONTopologicalDecoderEdges",
]
137 changes: 134 additions & 3 deletions src/anemoi/graphs/edges/builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@

import networkx as nx
import numpy as np
import scipy
import torch
from anemoi.utils.config import DotDict
from hydra.utils import instantiate
Expand Down Expand Up @@ -80,10 +81,8 @@ def get_edge_index(self, graph: HeteroData) -> torch.Tensor:
source_nodes, target_nodes = self.prepare_node_data(graph)

adjmat = self.get_adjacency_matrix(source_nodes, target_nodes)

# Get source & target indices of the edges
edge_index = np.stack([adjmat.col, adjmat.row], axis=0)

return torch.from_numpy(edge_index).to(torch.int32)

def register_edges(self, graph: HeteroData) -> HeteroData:
Expand Down Expand Up @@ -381,7 +380,13 @@ class MultiScaleEdges(BaseEdgeBuilder):
Update the graph with the edges.
"""

VALID_NODES = [TriNodes, HexNodes, LimitedAreaTriNodes, LimitedAreaHexNodes, StretchedTriNodes]
VALID_NODES = [
TriNodes,
HexNodes,
LimitedAreaTriNodes,
LimitedAreaHexNodes,
StretchedTriNodes,
]

def __init__(self, source_name: str, target_name: str, x_hops: int, **kwargs):
super().__init__(source_name, target_name)
Expand Down Expand Up @@ -444,3 +449,129 @@ def update_graph(self, graph: HeteroData, attrs_config: DotDict | None = None) -
), f"{self.__class__.__name__} requires {','.join(valid_node_names)} nodes."

return super().update_graph(graph, attrs_config)


class ICONTopologicalBaseEdgeBuilder(BaseEdgeBuilder):
"""Base class for computing edges based on ICON grid topology.

Attributes
----------
source_name : str
The name of the source nodes.
target_name : str
The name of the target nodes.
icon_mesh : str
The name of the ICON mesh (defines both the processor mesh and the data)
"""

def __init__(
self,
source_name: str,
target_name: str,
icon_mesh: str,
source_mask_attr_name: str | None = None,
target_mask_attr_name: str | None = None,
):
self.icon_mesh = icon_mesh
super().__init__(source_name, target_name, source_mask_attr_name, target_mask_attr_name)

def update_graph(self, graph: HeteroData, attrs_config: DotDict = None) -> HeteroData:
"""Update the graph with the edges."""
assert self.icon_mesh is not None, f"{self.__class__.__name__} requires initialized icon_mesh."
self.icon_sub_graph = graph[self.icon_mesh][self.sub_graph_address]
return super().update_graph(graph, attrs_config)

def get_adjacency_matrix(self, source_nodes: NodeStorage, target_nodes: NodeStorage):
"""Parameters
----------
source_nodes : NodeStorage
The source nodes.
target_nodes : NodeStorage
The target nodes.
"""
LOGGER.info(f"Using ICON topology {self.source_name}>{self.target_name}")
nrows = self.icon_sub_graph.num_edges
adj_matrix = scipy.sparse.coo_matrix(
(
np.ones(nrows),
(
self.icon_sub_graph.edge_vertices[:, self.vertex_index[0]],
self.icon_sub_graph.edge_vertices[:, self.vertex_index[1]],
),
)
)
return adj_matrix


class ICONTopologicalProcessorEdges(ICONTopologicalBaseEdgeBuilder):
"""Computes edges based on ICON grid topology: processor grid built
from ICON grid vertices.
"""

def __init__(
self,
source_name: str,
target_name: str,
icon_mesh: str,
source_mask_attr_name: str | None = None,
target_mask_attr_name: str | None = None,
):
self.sub_graph_address = "_multi_mesh"
self.vertex_index = (1, 0)
super().__init__(
source_name,
target_name,
icon_mesh,
source_mask_attr_name,
target_mask_attr_name,
)


class ICONTopologicalEncoderEdges(ICONTopologicalBaseEdgeBuilder):
"""Computes encoder edges based on ICON grid topology: ICON cell
circumcenters for mapped onto processor grid built from ICON grid
vertices.
"""

def __init__(
self,
source_name: str,
target_name: str,
icon_mesh: str,
source_mask_attr_name: str | None = None,
target_mask_attr_name: str | None = None,
):
self.sub_graph_address = "_cell_grid"
self.vertex_index = (1, 0)
super().__init__(
source_name,
target_name,
icon_mesh,
source_mask_attr_name,
target_mask_attr_name,
)


class ICONTopologicalDecoderEdges(ICONTopologicalBaseEdgeBuilder):
"""Computes encoder edges based on ICON grid topology: mapping from
processor grid built from ICON grid vertices onto ICON cell
circumcenters.
"""

def __init__(
self,
source_name: str,
target_name: str,
icon_mesh: str,
source_mask_attr_name: str | None = None,
target_mask_attr_name: str | None = None,
):
self.sub_graph_address = "_cell_grid"
self.vertex_index = (0, 1)
super().__init__(
source_name,
target_name,
icon_mesh,
source_mask_attr_name,
target_mask_attr_name,
)
Loading
Loading