Skip to content

Commit

Permalink
Merge branch 'develop' into feature/list-of-edge-builders
Browse files Browse the repository at this point in the history
  • Loading branch information
JPXKQX authored Nov 22, 2024
2 parents fddb7d1 + 85c9c74 commit 35159a4
Show file tree
Hide file tree
Showing 6 changed files with 270 additions and 0 deletions.
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,9 @@ Keep it human-readable, your future self will thank you!

### Added
- feat: Define node sets and edges based on an ICON icosahedral mesh (#53)
- feat: Add support for `post_processors` in the recipe. (#71)
- feat: Add `RemoveUnconnectedNodes` post processor to clean unconnected nodes in LAM. (#71)
- feat: Support for multiple edge builders between two sets of nodes (#70)

## [0.4.0 - LAM and stretched graphs](https://github.com/ecmwf/anemoi-graphs/compare/0.3.0...0.4.0) - 2024-11-08
Expand Down
27 changes: 27 additions & 0 deletions src/anemoi/graphs/create.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,32 @@ def clean(self, graph: HeteroData) -> HeteroData:

return graph

def post_process(self, graph: HeteroData) -> HeteroData:
    """Allow post-processing of the resulting graph.

    This method applies any configured post-processors to the graph,
    which can modify or enhance the graph structure or attributes.

    Parameters
    ----------
    graph : HeteroData
        The graph to be post-processed.

    Returns
    -------
    HeteroData
        The post-processed graph.

    Notes
    -----
    Post-processors are applied in the order they are specified in the configuration.
    Each post-processor should implement an `update_graph` method that takes and returns a HeteroData object.
    A missing, empty or null ``post_processors`` entry leaves the graph untouched.
    """
    # `or ()` also covers a key that is present but null (e.g. an empty YAML
    # entry); `.get(key, [])` would return None in that case and the loop
    # below would fail with "NoneType is not iterable".
    processors = self.config.get("post_processors") or ()
    for processor in processors:
        graph = instantiate(processor).update_graph(graph)

    return graph

def save(self, graph: HeteroData, save_path: Path, overwrite: bool = False) -> None:
"""Save the generated graph to the output path.
Expand Down Expand Up @@ -141,6 +167,7 @@ def create(self, save_path: Path | None = None, overwrite: bool = False) -> Hete
graph = HeteroData()
graph = self.update_graph(graph)
graph = self.clean(graph)
graph = self.post_process(graph)

if save_path is None:
LOGGER.warning("No output path specified. The graph will not be saved.")
Expand Down
3 changes: 3 additions & 0 deletions src/anemoi/graphs/processors/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
from .post_process import RemoveUnconnectedNodes

# `__all__` must list the exported names as strings. Listing the class object
# itself makes `from anemoi.graphs.processors import *` raise a TypeError.
__all__ = ["RemoveUnconnectedNodes"]
149 changes: 149 additions & 0 deletions src/anemoi/graphs/processors/post_process.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,149 @@
# (C) Copyright 2024 Anemoi contributors.
#
# This software is licensed under the terms of the Apache Licence Version 2.0
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
#
# In applying this licence, ECMWF does not waive the privileges and immunities
# granted to it by virtue of its status as an intergovernmental organisation
# nor does it submit to any jurisdiction.

from __future__ import annotations

import logging
from abc import ABC
from abc import abstractmethod

import torch
from torch_geometric.data import HeteroData

LOGGER = logging.getLogger(__name__)


class PostProcessor(ABC):
    """Interface shared by the graph post-processors.

    Concrete subclasses must override `update_graph`, which receives a graph
    and returns the processed graph.
    """

    @abstractmethod
    def update_graph(self, graph: HeteroData) -> HeteroData:
        """Process the graph and return it.

        The base implementation only raises, so reaching it through
        ``super()`` fails loudly instead of silently returning nothing.
        """
        message = f"The {self.__class__.__name__} class does not implement the method update_graph()."
        raise NotImplementedError(message)


class BaseMaskingProcessor(PostProcessor, ABC):
    """Base class for mask based processor.

    Subclasses provide `compute_mask`, a boolean mask over the nodes of
    `nodes_name` in which True marks a node to keep. `update_graph` then
    filters the node attributes with that mask, renumbers the edge indices
    that refer to the node set and optionally stores the original indices of
    the kept nodes.

    Attributes
    ----------
    nodes_name : str
        Name of the node set the mask applies to.
    save_mask_indices_to_attr : str, optional
        Name of the node attribute receiving the original indices of the
        kept nodes. Nothing is stored when it is None.
    mask : torch.Tensor, optional
        Mask computed by the last call to `update_graph`; None before that.
    """

    def __init__(
        self,
        nodes_name: str,
        save_mask_indices_to_attr: str | None = None,
    ) -> None:
        self.nodes_name = nodes_name
        self.save_mask_indices_to_attr = save_mask_indices_to_attr
        # Filled in by `update_graph`; stays None until a mask is computed.
        self.mask: torch.Tensor | None = None

    def removing_nodes(self, graph: HeteroData) -> HeteroData:
        """Remove nodes based on the mask passed."""
        node_store = graph[self.nodes_name]
        for attr_name in node_store.node_attrs():
            node_store[attr_name] = node_store[attr_name][self.mask]

        return graph

    def create_indices_mapper_from_mask(self) -> dict[int, int]:
        """Map the original index of every kept node to its new index."""
        kept_indices = torch.where(self.mask)[0].tolist()
        return {old_idx: new_idx for new_idx, old_idx in enumerate(kept_indices)}

    def _new_positions(self) -> torch.Tensor:
        """Return a lookup tensor: original index -> new index, -1 if removed."""
        mask = self.mask
        positions = torch.full((mask.numel(),), -1, dtype=torch.long, device=mask.device)
        positions[mask] = torch.arange(int(mask.sum()), dtype=torch.long, device=mask.device)
        return positions

    def update_edge_indices(self, graph: HeteroData) -> HeteroData:
        """Update the edge indices to the new position of the nodes.

        Raises
        ------
        ValueError
            If an edge refers to a node that the mask removes.
        """
        # A tensor lookup replaces the per-element `Tensor.apply_` callback,
        # which runs in Python for every edge and only supports CPU tensors.
        positions = self._new_positions()
        for edges_name in graph.edge_types:
            # Row 0 holds the source nodes, row 1 the target nodes.
            for row, endpoint_name in ((0, edges_name[0]), (1, edges_name[2])):
                if endpoint_name != self.nodes_name:
                    continue

                edge_index = graph[edges_name].edge_index
                remapped = positions.to(edge_index.device)[edge_index[row]]
                if bool((remapped < 0).any()):
                    raise ValueError(
                        f"The edges {edges_name} refer to nodes of {self.nodes_name} that are being removed."
                    )
                edge_index[row] = remapped

        return graph

    @abstractmethod
    def compute_mask(self, graph: HeteroData) -> torch.Tensor:
        """Return the boolean mask of the nodes to keep (True = keep)."""
        ...

    def add_attribute(self, graph: HeteroData) -> HeteroData:
        """Add an attribute of the mask indices as node attribute."""
        if self.save_mask_indices_to_attr is None:
            return graph

        # One row per kept node. `reshape(-1, 1)` does not depend on the
        # inferred number of nodes and stays valid when no node is kept.
        mask_indices = torch.where(self.mask)[0].reshape(-1, 1)
        graph[self.nodes_name][self.save_mask_indices_to_attr] = mask_indices
        LOGGER.info(
            "An attribute %s has been added with the indices to mask the nodes from the original graph.",
            self.save_mask_indices_to_attr,
        )

        return graph

    def update_graph(self, graph: HeteroData) -> HeteroData:
        """Post-process the graph.

        Parameters
        ----------
        graph: HeteroData
            The graph to post-process.

        Returns
        -------
        HeteroData
            The post-processed graph.
        """
        self.mask = self.compute_mask(graph)
        LOGGER.info("Removing %d nodes from %s.", int((~self.mask).sum()), self.nodes_name)
        graph = self.removing_nodes(graph)
        graph = self.update_edge_indices(graph)
        # Stored after the pruning so the new attribute is not masked itself.
        graph = self.add_attribute(graph)
        return graph


class RemoveUnconnectedNodes(BaseMaskingProcessor):
    """Remove unconnected nodes in the graph.

    A node of `nodes_name` is kept when it appears as source or target in at
    least one edge of the graph, or when the `ignore` attribute flags it.

    Attributes
    ----------
    nodes_name: str
        Name of the unconnected nodes to remove.
    ignore: str, optional
        Name of an attribute to ignore when removing nodes. Nodes with
        this attribute set to True will not be removed.
    save_mask_indices_to_attr: str, optional
        Name of the attribute to save the mask indices. If provided,
        the indices of the kept nodes will be saved in this attribute.

    Methods
    -------
    compute_mask(graph)
        Compute the mask of the connected nodes.
    update_graph(graph)
        Post-process the graph.
    """

    def __init__(
        self,
        nodes_name: str,
        save_mask_indices_to_attr: str | None = None,
        ignore: str | None = None,
    ) -> None:
        super().__init__(nodes_name=nodes_name, save_mask_indices_to_attr=save_mask_indices_to_attr)
        self.ignore = ignore

    def compute_mask(self, graph: HeteroData) -> torch.Tensor:
        """Compute the mask of connected nodes."""
        node_store = graph[self.nodes_name]
        is_connected = torch.zeros(node_store.num_nodes, dtype=torch.bool)

        if self.ignore is not None:
            LOGGER.info(f"The nodes with {self.ignore}=True will not be removed.")
            # Flagged nodes are treated as connected so they are never dropped.
            protected = node_store[self.ignore].bool().squeeze()
            is_connected[protected] = True

        for edge_type, edge_store in graph.edge_items():
            # Row 0 of edge_index holds sources, row 1 holds targets.
            endpoint_names = (edge_type[0], edge_type[2])
            for row, endpoint_name in enumerate(endpoint_names):
                if endpoint_name == self.nodes_name:
                    is_connected[edge_store.edge_index[row]] = True

        return is_connected
9 changes: 9 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,15 @@ def graph_with_nodes() -> HeteroData:
return graph


@pytest.fixture
def graph_with_isolated_nodes() -> HeteroData:
    """Graph with 6 nodes whose first and last nodes take part in no edge."""
    graph = HeteroData()

    nodes = graph["test_nodes"]
    nodes.x = torch.tensor([[1], [2], [3], [4], [5], [6]])
    nodes["mask_attr"] = torch.tensor([[1], [1], [1], [0], [0], [0]], dtype=torch.bool)

    # Edges only touch nodes 1-4, leaving nodes 0 and 5 isolated.
    edges = graph["test_nodes", "to", "test_nodes"]
    edges.edge_index = torch.tensor([[2, 3, 4], [1, 2, 3]])
    return graph


@pytest.fixture
def graph_nodes_and_edges() -> HeteroData:
"""Graph with 1 set of nodes and edges."""
Expand Down
79 changes: 79 additions & 0 deletions tests/processors/test_post_process.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
# (C) Copyright 2024 Anemoi contributors.
#
# This software is licensed under the terms of the Apache Licence Version 2.0
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
#
# In applying this licence, ECMWF does not waive the privileges and immunities
# granted to it by virtue of its status as an intergovernmental organisation
# nor does it submit to any jurisdiction.
from __future__ import annotations

import pytest
import torch
from torch_geometric.data import HeteroData

from anemoi.graphs.processors.post_process import RemoveUnconnectedNodes


def test_remove_unconnected_nodes(graph_with_isolated_nodes: HeteroData):
    """Isolated nodes are dropped and no index attribute is stored by default."""
    pruned = RemoveUnconnectedNodes(
        nodes_name="test_nodes", ignore=None, save_mask_indices_to_attr=None
    ).update_graph(graph_with_isolated_nodes)

    node_store = pruned["test_nodes"]
    expected_x = torch.tensor([[2], [3], [4], [5]])

    assert node_store.num_nodes == 4
    assert torch.equal(node_store.x, expected_x)
    assert "original_indices" not in node_store


def test_remove_unconnected_nodes_with_indices_attr(graph_with_isolated_nodes: HeteroData):
    """Pruning renumbers the edges and stores the original indices of the kept nodes."""
    pruned = RemoveUnconnectedNodes(
        nodes_name="test_nodes", ignore=None, save_mask_indices_to_attr="original_indices"
    ).update_graph(graph_with_isolated_nodes)

    node_store = pruned["test_nodes"]
    edge_store = pruned["test_nodes", "to", "test_nodes"]

    expected_x = torch.tensor([[2], [3], [4], [5]])
    # Node 0 is removed, so every remaining index shifts down by one.
    expected_edge_index = torch.tensor([[1, 2, 3], [0, 1, 2]])
    expected_original_indices = torch.tensor([[1], [2], [3], [4]])

    assert node_store.num_nodes == 4
    assert torch.equal(node_store.x, expected_x)
    assert torch.equal(edge_store.edge_index, expected_edge_index)
    assert torch.equal(node_store.original_indices, expected_original_indices)


def test_remove_unconnected_nodes_with_ignore(graph_with_isolated_nodes: HeteroData):
    """Nodes flagged by the ignore attribute are kept even when unconnected."""
    pruned = RemoveUnconnectedNodes(
        nodes_name="test_nodes", ignore="mask_attr", save_mask_indices_to_attr=None
    ).update_graph(graph_with_isolated_nodes)

    node_store = pruned["test_nodes"]
    edge_store = pruned["test_nodes", "to", "test_nodes"]

    expected_x = torch.tensor([[1], [2], [3], [4], [5]])
    # Only the last node is removed, so the edge indices do not change.
    expected_edge_index = torch.tensor([[2, 3, 4], [1, 2, 3]])

    assert node_store.num_nodes == 5
    assert torch.equal(node_store.x, expected_x)
    assert torch.equal(edge_store.edge_index, expected_edge_index)


@pytest.mark.parametrize(
    "nodes_name,ignore,save_mask_indices_to_attr",
    [
        ("test_nodes", None, "original_indices"),
        ("test_nodes", "mask_attr", None),
        ("test_nodes", None, None),
    ],
)
def test_remove_unconnected_nodes_parametrized(
    graph_with_isolated_nodes: HeteroData,
    nodes_name: str,
    ignore: str | None,
    save_mask_indices_to_attr: str | None,
):
    """Check node count and node attributes for several processor configurations."""
    processor = RemoveUnconnectedNodes(
        nodes_name=nodes_name, ignore=ignore, save_mask_indices_to_attr=save_mask_indices_to_attr
    )

    # The processor works on the graph it is given and returns it, so the
    # attribute names must be captured before processing. Comparing against
    # the fixture afterwards would compare the graph with itself.
    attrs_before = list(graph_with_isolated_nodes[nodes_name].node_attrs())

    graph = processor.update_graph(graph_with_isolated_nodes)

    assert isinstance(graph, HeteroData)
    pruned_nodes = 4 if ignore is None else 5
    assert graph[nodes_name].num_nodes == pruned_nodes

    if save_mask_indices_to_attr:
        assert save_mask_indices_to_attr in graph[nodes_name]
        assert graph[nodes_name][save_mask_indices_to_attr].ndim == 2
    else:
        assert graph[nodes_name].node_attrs() == attrs_before

0 comments on commit 35159a4

Please sign in to comment.