Skip to content

Commit

Permalink
refactor: unified the three ICON Edgebuilders and the two Nodebuilders.
Browse files Browse the repository at this point in the history
  • Loading branch information
fprill committed Sep 30, 2024
1 parent f6d63f0 commit 8efd2c6
Show file tree
Hide file tree
Showing 2 changed files with 50 additions and 90 deletions.
101 changes: 33 additions & 68 deletions src/anemoi/graphs/edges/builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -353,8 +353,8 @@ def update_graph(self, graph: HeteroData, attrs_config: DotDict | None = None) -
return super().update_graph(graph, attrs_config)


class ICONTopologicalProcessorEdges(BaseEdgeBuilder):
"""Computes edges based on ICON grid topology.
class ICONTopologicalBaseEdgeBuilder(BaseEdgeBuilder):
"""Base class for computing edges based on ICON grid topology.
Attributes
----------
Expand All @@ -372,7 +372,7 @@ def __init__(self, source_name: str, target_name: str, icon_mesh: str):

def update_graph(self, graph: HeteroData, attrs_config: DotDict = None) -> HeteroData:
"""Update the graph with the edges."""
self.multi_mesh = graph[self.icon_mesh]["_multi_mesh"]
self.icon_sub_graph = graph[self.icon_mesh][self.sub_graph_address]
return super().update_graph(graph, attrs_config)

def get_adjacency_matrix(self, source_nodes: NodeStorage, target_nodes: NodeStorage):
Expand All @@ -384,84 +384,49 @@ def get_adjacency_matrix(self, source_nodes: NodeStorage, target_nodes: NodeStor
The target nodes.
"""
LOGGER.info(f"Using ICON topology {self.source_name}>{self.target_name}")
nrows = self.multi_mesh.num_edges
nrows = self.icon_sub_graph.num_edges
adj_matrix = scipy.sparse.coo_matrix(
(np.ones(nrows), (self.multi_mesh.edge_vertices[:, 1], self.multi_mesh.edge_vertices[:, 0]))
(
np.ones(nrows),
(
self.icon_sub_graph.edge_vertices[:, self.vertex_index[0]],
self.icon_sub_graph.edge_vertices[:, self.vertex_index[1]],
),
)
)
return adj_matrix


class ICONTopologicalEncoderEdges(BaseEdgeBuilder):
"""Computes edges based on ICON grid topology.
Attributes
----------
source_name : str
The name of the source nodes.
target_name : str
The name of the target nodes.
icon_mesh : str
The name of the ICON mesh (defines both the processor mesh and the data)
class ICONTopologicalProcessorEdges(ICONTopologicalBaseEdgeBuilder):
"""Computes edges based on ICON grid topology: processor grid built
from ICON grid vertices.
"""

def __init__(self, source_name: str, target_name: str, icon_mesh: str):
self.icon_mesh = icon_mesh
super().__init__(source_name, target_name)
self.sub_graph_address = "_multi_mesh"
self.vertex_index = (1, 0)
super().__init__(source_name, target_name, icon_mesh)

def update_graph(self, graph: HeteroData, attrs_config: DotDict = None) -> HeteroData:
"""Update the graph with the edges."""
self.cell_data_grid = graph[self.icon_mesh]["_cell_grid"]
return super().update_graph(graph, attrs_config)

def get_adjacency_matrix(self, source_nodes: NodeStorage, target_nodes: NodeStorage):
"""Parameters
----------
source_nodes : NodeStorage
The source nodes.
target_nodes : NodeStorage
The target nodes.
"""
LOGGER.info(f"Using ICON topology {self.source_name}>{self.target_name}")
nrows = self.cell_data_grid.num_edges
adj_matrix = scipy.sparse.coo_matrix(
(np.ones(nrows), (self.cell_data_grid.edge_vertices[:, 1], self.cell_data_grid.edge_vertices[:, 0]))
)
return adj_matrix
class ICONTopologicalEncoderEdges(ICONTopologicalBaseEdgeBuilder):
"""Computes encoder edges based on ICON grid topology: ICON cell
circumcenters for mapped onto processor grid built from ICON grid
vertices.
"""

def __init__(self, source_name: str, target_name: str, icon_mesh: str):
self.sub_graph_address = "_cell_grid"
self.vertex_index = (1, 0)
super().__init__(source_name, target_name, icon_mesh)

class ICONTopologicalDecoderEdges(BaseEdgeBuilder):
"""Computes edges based on ICON grid topology.

Attributes
----------
source_name : str
The name of the source nodes.
target_name : str
The name of the target nodes.
icon_mesh : str
The name of the ICON mesh (defines both the processor mesh and the data)
class ICONTopologicalDecoderEdges(ICONTopologicalBaseEdgeBuilder):
"""Computes encoder edges based on ICON grid topology: mapping from
processor grid built from ICON grid vertices onto ICON cell
circumcenters.
"""

def __init__(self, source_name: str, target_name: str, icon_mesh: str):
self.icon_mesh = icon_mesh
super().__init__(source_name, target_name)

def update_graph(self, graph: HeteroData, attrs_config: DotDict = None) -> HeteroData:
"""Update the graph with the edges."""
self.cell_data_grid = graph[self.icon_mesh]["_cell_grid"]
return super().update_graph(graph, attrs_config)

def get_adjacency_matrix(self, source_nodes: NodeStorage, target_nodes: NodeStorage):
"""Parameters
----------
source_nodes : NodeStorage
The source nodes.
target_nodes : NodeStorage
The target nodes.
"""
LOGGER.info(f"Using ICON topology {self.source_name}>{self.target_name}")
nrows = self.cell_data_grid.num_edges
adj_matrix = scipy.sparse.coo_matrix(
(np.ones(nrows), (self.cell_data_grid.edge_vertices[:, 0], self.cell_data_grid.edge_vertices[:, 1]))
)
return adj_matrix
self.sub_graph_address = "_cell_grid"
self.vertex_index = (0, 1)
super().__init__(source_name, target_name, icon_mesh)
39 changes: 17 additions & 22 deletions src/anemoi/graphs/nodes/builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -304,8 +304,8 @@ def register_attributes(self, graph: HeteroData, config: DotDict) -> HeteroData:
return super().register_attributes(graph, config)


class ICONMultimeshNodes(BaseNodeBuilder):
"""Processor mesh based on an ICON grid.
class ICONTopologicalBaseNodeBuilder(BaseNodeBuilder):
"""Base class for data mesh or processor mesh based on an ICON grid.
Parameters
----------
Expand All @@ -321,35 +321,30 @@ def __init__(self, name: str, icon_mesh: str) -> None:

def update_graph(self, graph: HeteroData, attr_config: DotDict | None = None) -> HeteroData:
"""Update the graph with new nodes."""
self.multi_mesh = graph[self.icon_mesh]["_multi_mesh"]
self.icon_sub_graph = graph[self.icon_mesh][self.sub_graph_address]
return super().update_graph(graph, attr_config)

def get_coordinates(self) -> torch.Tensor:
return torch.from_numpy(self.multi_mesh.nodeset.gc_vertices.astype(np.float32)).fliplr()

class ICONMultimeshNodes(ICONTopologicalBaseNodeBuilder):
"""Processor mesh based on an ICON grid."""

class ICONCellGridNodes(BaseNodeBuilder):
"""Data mesh based on an ICON grid.
def __init__(self, name: str, icon_mesh: str) -> None:
self.sub_graph_address = "_multi_mesh"
super().__init__(name, icon_mesh)

Parameters
----------
name : str
key for the nodes in the HeteroData graph object.
icon_mesh : str
key corresponding to the ICON mesh (cells and vertices).
"""
def get_coordinates(self) -> torch.Tensor:
return torch.from_numpy(self.icon_sub_graph.nodeset.gc_vertices.astype(np.float32)).fliplr()

def __init__(self, name: str, icon_mesh: str) -> None:
self.icon_mesh = icon_mesh
super().__init__(name)

def update_graph(self, graph: HeteroData, attr_config: DotDict | None = None) -> HeteroData:
"""Update the graph with new nodes."""
self.cell_grid = graph[self.icon_mesh]["_cell_grid"]
return super().update_graph(graph, attr_config)
class ICONCellGridNodes(ICONTopologicalBaseNodeBuilder):
"""Data mesh based on an ICON grid."""

def __init__(self, name: str, icon_mesh: str) -> None:
self.sub_graph_address = "_cell_grid"
super().__init__(name, icon_mesh)

def get_coordinates(self) -> torch.Tensor:
return torch.from_numpy(self.cell_grid.nodeset[0].gc_vertices.astype(np.float32)).fliplr()
return torch.from_numpy(self.icon_sub_graph.nodeset[0].gc_vertices.astype(np.float32)).fliplr()


class HEALPixNodes(BaseNodeBuilder):
Expand Down

0 comments on commit 8efd2c6

Please sign in to comment.