Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feature: support icon graphs #53

Merged
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
Show all changes
37 commits
Select commit Hold shift + click to select a range
100c251
feat: topology-based encoder/processor/decoder graphs derived from an…
fprill Sep 26, 2024
9848cf0
docs: attempt to set a correct license header.
fprill Sep 26, 2024
69ae2fa
doc: changed the docstring style to the Numpy-style instead of Google…
fprill Sep 30, 2024
d858d0e
refactor: changed variable name `verts` into `vertices`.
fprill Sep 30, 2024
5e21ff5
fix: changed float division to double slash operator.
fprill Sep 30, 2024
3d662d1
refactor: removed unnecessary `else` branch.
fprill Sep 30, 2024
cfd82a9
refactor: more descriptive variable name in for loop.
fprill Sep 30, 2024
48230ec
refactor: renamed variable (ignoring the minus sign of `phi`).
fprill Sep 30, 2024
6ad883f
refactor: changed name of temporary variable.
fprill Sep 30, 2024
1862a66
refactor: remove unnecessary `else´branch.
fprill Sep 30, 2024
6f6e501
refactor: added type annotation for completeness.
fprill Sep 30, 2024
9bc7ae0
refactor: remove default in function argument.
fprill Sep 30, 2024
0fae14d
refactor: change argument name to a more understandable name.
fprill Sep 30, 2024
a44fb01
refactor: remove redundant code.
fprill Sep 30, 2024
25630cc
refactor: more Pythonic if-else statement.
fprill Sep 30, 2024
612afe5
refactor: more Pythonic if-else statement.
fprill Sep 30, 2024
368427c
refactor: use more appropriate `LOGGER.debug` instead of verbosity flag.
fprill Sep 30, 2024
5c76e10
refactor: more appropriate variable name.
fprill Sep 30, 2024
f6d63f0
refactor: more appropriate variable name.
fprill Sep 30, 2024
8efd2c6
refactor: unified the three ICON Edgebuilders and the two Nodebuilders.
fprill Sep 30, 2024
1961150
refactor: more verbose but also clearer names for variables in mesh c…
fprill Sep 30, 2024
30dee38
remove: removed obsolete function `set_constant_edge_id`.
fprill Sep 30, 2024
92cb023
refactor: replaced the sequential ID counter by a UUID.
fprill Sep 30, 2024
9937a61
revert change of copyright notice
MeraX Oct 16, 2024
ca67c7d
Merge branch 'develop' into feature/support-icon-graphs
fprill Nov 6, 2024
75057c3
[fix] add encoder & processor edges to the __all__ variable
fprill Nov 6, 2024
dcff399
[refactor] move auxiliary functions to utils.py in graphs/generate/
fprill Nov 6, 2024
9b19fad
[doc] added empty torso for class documentation.
fprill Nov 6, 2024
611e264
[fix] fixed interfaces (masks).
fprill Nov 8, 2024
797a6a2
[refactor] remove edge attribute calculation from this PR.
fprill Nov 8, 2024
37cb87a
[chore] adjust copyright notice.
fprill Nov 8, 2024
90a1721
[doc] elaborate on icon mesh classes in rst file.
fprill Nov 8, 2024
b45e21e
Add Icon tests
MeraX Nov 11, 2024
86fc909
Merge remote-tracking branch 'github/develop' into feature/support-ic…
MeraX Nov 11, 2024
63ea1b0
update Change log
MeraX Nov 11, 2024
4aaa9ff
Add __future__ annotations
MeraX Nov 11, 2024
01a94f0
fix change log
MeraX Nov 12, 2024
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -47,11 +47,13 @@ dependencies = [
"healpy>=1.17",
"hydra-core>=1.3",
"matplotlib>=3.4",
"netcdf4",
"networkx>=3.1",
"plotly>=5.19",
"torch>=2.2",
"torch-geometric>=2.3.1,<2.5",
"trimesh>=4.1",
"typeguard",
]

optional-dependencies.all = [ ]
Expand Down
8 changes: 4 additions & 4 deletions src/anemoi/graphs/commands/__init__.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
#!/usr/bin/env python
# (C) Copyright 2024 ECMWF.
# (C) Copyright 2024 ECMWF, DWD.
#
# This software is licensed under the terms of the Apache Licence Version 2.0
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
# In applying this licence, ECMWF does not waive the privileges and immunities
# granted to it by virtue of its status as an intergovernmental organisation
# nor does it submit to any jurisdiction.
# In applying this licence, the above institution do not waive the privileges
# and immunities granted to it by virtue of its status as an intergovernmental
# organisation nor does it submit to any jurisdiction.
#

import os
Expand Down
5 changes: 4 additions & 1 deletion src/anemoi/graphs/edges/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
from .builder import CutOffEdges
from .builder import ICONTopologicalDecoderEdges
from .builder import ICONTopologicalEncoderEdges
from .builder import ICONTopologicalProcessorEdges
from .builder import KNNEdges
from .builder import MultiScaleEdges

fprill marked this conversation as resolved.
Show resolved Hide resolved
__all__ = ["KNNEdges", "CutOffEdges", "MultiScaleEdges"]
__all__ = ["KNNEdges", "CutOffEdges", "MultiScaleEdges", "ICONTopologicalDecoderEdges"]
117 changes: 115 additions & 2 deletions src/anemoi/graphs/edges/builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

import networkx as nx
import numpy as np
import scipy
import torch
from anemoi.utils.config import DotDict
from hydra.utils import instantiate
Expand Down Expand Up @@ -58,10 +59,8 @@ def get_edge_index(self, graph: HeteroData) -> torch.Tensor:
source_nodes, target_nodes = self.prepare_node_data(graph)

adjmat = self.get_adjacency_matrix(source_nodes, target_nodes)

# Get source & target indices of the edges
edge_index = np.stack([adjmat.col, adjmat.row], axis=0)

return torch.from_numpy(edge_index).to(torch.int32)

def register_edges(self, graph: HeteroData) -> HeteroData:
Expand Down Expand Up @@ -352,3 +351,117 @@ def update_graph(self, graph: HeteroData, attrs_config: DotDict | None = None) -
self.node_type = graph[self.source_name].node_type

return super().update_graph(graph, attrs_config)


class ICONTopologicalProcessorEdges(BaseEdgeBuilder):
"""Computes edges based on ICON grid topology.

Attributes
----------
source_name : str
The name of the source nodes.
target_name : str
The name of the target nodes.
icon_mesh : str
The name of the ICON mesh (defines both the processor mesh and the data)
"""

def __init__(self, source_name: str, target_name: str, icon_mesh: str):
self.icon_mesh = icon_mesh
super().__init__(source_name, target_name)

def update_graph(self, graph: HeteroData, attrs_config: DotDict = None) -> HeteroData:
"""Update the graph with the edges."""
self.multi_mesh = graph[self.icon_mesh]["_multi_mesh"]
fprill marked this conversation as resolved.
Show resolved Hide resolved
return super().update_graph(graph, attrs_config)

def get_adjacency_matrix(self, source_nodes: NodeStorage, target_nodes: NodeStorage):
"""Parameters
----------
source_nodes : NodeStorage
The source nodes.
target_nodes : NodeStorage
The target nodes.
"""
LOGGER.info(f"Using ICON topology {self.source_name}>{self.target_name}")
nrows = self.multi_mesh.num_edges
adj_matrix = scipy.sparse.coo_matrix(
(np.ones(nrows), (self.multi_mesh.edge_vertices[:, 1], self.multi_mesh.edge_vertices[:, 0]))
)
return adj_matrix


class ICONTopologicalEncoderEdges(BaseEdgeBuilder):
"""Computes edges based on ICON grid topology.

Attributes
----------
source_name : str
The name of the source nodes.
target_name : str
The name of the target nodes.
icon_mesh : str
The name of the ICON mesh (defines both the processor mesh and the data)
"""

def __init__(self, source_name: str, target_name: str, icon_mesh: str):
self.icon_mesh = icon_mesh
super().__init__(source_name, target_name)

def update_graph(self, graph: HeteroData, attrs_config: DotDict = None) -> HeteroData:
"""Update the graph with the edges."""
self.cell_data_grid = graph[self.icon_mesh]["_cell_grid"]
return super().update_graph(graph, attrs_config)

def get_adjacency_matrix(self, source_nodes: NodeStorage, target_nodes: NodeStorage):
"""Parameters
----------
source_nodes : NodeStorage
The source nodes.
target_nodes : NodeStorage
The target nodes.
"""
LOGGER.info(f"Using ICON topology {self.source_name}>{self.target_name}")
nrows = self.cell_data_grid.num_edges
adj_matrix = scipy.sparse.coo_matrix(
(np.ones(nrows), (self.cell_data_grid.edge_vertices[:, 1], self.cell_data_grid.edge_vertices[:, 0]))
)
return adj_matrix


class ICONTopologicalDecoderEdges(BaseEdgeBuilder):
"""Computes edges based on ICON grid topology.

Attributes
----------
source_name : str
The name of the source nodes.
target_name : str
The name of the target nodes.
icon_mesh : str
The name of the ICON mesh (defines both the processor mesh and the data)
"""

def __init__(self, source_name: str, target_name: str, icon_mesh: str):
self.icon_mesh = icon_mesh
super().__init__(source_name, target_name)

def update_graph(self, graph: HeteroData, attrs_config: DotDict = None) -> HeteroData:
"""Update the graph with the edges."""
self.cell_data_grid = graph[self.icon_mesh]["_cell_grid"]
return super().update_graph(graph, attrs_config)

def get_adjacency_matrix(self, source_nodes: NodeStorage, target_nodes: NodeStorage):
"""Parameters
----------
source_nodes : NodeStorage
The source nodes.
target_nodes : NodeStorage
The target nodes.
"""
LOGGER.info(f"Using ICON topology {self.source_name}>{self.target_name}")
nrows = self.cell_data_grid.num_edges
adj_matrix = scipy.sparse.coo_matrix(
(np.ones(nrows), (self.cell_data_grid.edge_vertices[:, 0], self.cell_data_grid.edge_vertices[:, 1]))
)
return adj_matrix
Loading