Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[feature] support post-processors #71

Merged
merged 25 commits into from
Nov 22, 2024
Merged
Show file tree
Hide file tree
Changes from 10 commits
Commits
Show all changes
25 commits
Select commit Hold shift + click to select a range
c3c2a4c
feat: support post-processors
JPXKQX Oct 25, 2024
a0e869b
update CHANGELOG.md
JPXKQX Oct 25, 2024
8a4014c
feat: ignore and docstrings
JPXKQX Oct 26, 2024
92c5b4f
fix: add tests
JPXKQX Oct 26, 2024
404be91
fix: minor updates
JPXKQX Oct 28, 2024
1a3635c
fix: import annotations
JPXKQX Oct 28, 2024
c742bf7
Merge branch 'develop' into feature/remove-unconnected-nodes
JPXKQX Oct 28, 2024
dd5f187
fix: docstring & abstractmethod
JPXKQX Oct 28, 2024
877404b
Merge branch 'feature/remove-unconnected-nodes' of github.com:ecmwf/a…
JPXKQX Oct 28, 2024
1a27371
fix: homogeneize docstrings
JPXKQX Oct 28, 2024
744607e
fix: change default
JPXKQX Nov 6, 2024
ce5879c
Merge branch 'develop' into feature/remove-unconnected-nodes
JPXKQX Nov 6, 2024
56a48cc
Merge branch 'develop' into feature/remove-unconnected-nodes
JPXKQX Nov 8, 2024
274859d
fix: changelog
JPXKQX Nov 8, 2024
b42f32c
Merge branch 'develop' into feature/remove-unconnected-nodes
JPXKQX Nov 11, 2024
63b9473
Merge branch 'develop' into feature/remove-unconnected-nodes
JPXKQX Nov 11, 2024
6305489
Merge branch 'develop' into feature/remove-unconnected-nodes
JPXKQX Nov 13, 2024
445b241
fix: changelog
JPXKQX Nov 13, 2024
49a60c7
Merge branch 'develop' into feature/remove-unconnected-nodes
JPXKQX Nov 21, 2024
56b4bff
fix: new attribute shape (num_nodes, 1)
JPXKQX Nov 21, 2024
99a144f
fix: tests
JPXKQX Nov 21, 2024
8ad6c76
fix: tests shape
JPXKQX Nov 21, 2024
09148ab
feat: provide base class for MaskBasedProcessor
JPXKQX Nov 21, 2024
c2f8f18
feat: set mask as attribute
JPXKQX Nov 22, 2024
8a2d88d
fix: clean hidden attributes before apply post processors
JPXKQX Nov 22, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,8 @@ Keep it human-readable, your future self will thank you!
- feat: New method update_graph(graph) in the GraphCreator class. (#60)
- feat: New class StretchedTriNodes to create a stretched mesh. (#51)
- feat: Expanded MultiScaleEdges to support multi-scale connections in stretched graphs. (#51)
- feat: Add support for `post_processors` in the recipe. (#71)
- feat: Add `RemoveUnconnectedNodes` post processor to clean unconnected nodes in LAM. (#71)
- feat: Add CONTRIBUTORS.md file. (#72)

### Changed
Expand Down
4 changes: 3 additions & 1 deletion src/anemoi/graphs/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
# (C) Copyright 2023 European Centre for Medium-Range Weather Forecasts.
# (C) Copyright 2024 Anemoi contributors.
#
# This software is licensed under the terms of the Apache Licence Version 2.0
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
#
# In applying this licence, ECMWF does not waive the privileges and immunities
# granted to it by virtue of its status as an intergovernmental organisation
# nor does it submit to any jurisdiction.
Expand Down
5 changes: 3 additions & 2 deletions src/anemoi/graphs/__main__.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
#!/usr/bin/env python
# (C) Copyright 2024 ECMWF.
# (C) Copyright 2024 Anemoi contributors.
#
# This software is licensed under the terms of the Apache Licence Version 2.0
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
#
# In applying this licence, ECMWF does not waive the privileges and immunities
# granted to it by virtue of its status as an intergovernmental organisation
# nor does it submit to any jurisdiction.
#


from anemoi.utils.cli import cli_main
from anemoi.utils.cli import make_parser
Expand Down
27 changes: 27 additions & 0 deletions src/anemoi/graphs/create.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,32 @@ def clean(self, graph: HeteroData) -> HeteroData:

return graph

def post_process(self, graph: HeteroData) -> HeteroData:
"""Allow post-processing of the resulting graph.

This method applies any configured post-processors to the graph,
which can modify or enhance the graph structure or attributes.

Parameters
----------
graph : HeteroData
The graph to be post-processed.

Returns
-------
HeteroData
The post-processed graph.

Notes
-----
Post-processors are applied in the order they are specified in the configuration.
Each post-processor should implement an `update_graph` method that takes and returns a HeteroData object.
"""
for processor in self.config.get("post_processors", {}):
graph = instantiate(processor).update_graph(graph)
HCookie marked this conversation as resolved.
Show resolved Hide resolved

return graph

def save(self, graph: HeteroData, save_path: Path, overwrite: bool = False) -> None:
"""Save the generated graph to the output path.

Expand Down Expand Up @@ -125,6 +151,7 @@ def create(self, save_path: Path | None = None, overwrite: bool = False) -> Hete
"""
graph = HeteroData()
graph = self.update_graph(graph)
graph = self.post_process(graph)
graph = self.clean(graph)

if save_path is None:
Expand Down
3 changes: 3 additions & 0 deletions src/anemoi/graphs/processors/__init__.py
JPXKQX marked this conversation as resolved.
Show resolved Hide resolved
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
from .post_process import RemoveUnconnectedNodes

__all__ = [RemoveUnconnectedNodes]
140 changes: 140 additions & 0 deletions src/anemoi/graphs/processors/post_process.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,140 @@
# (C) Copyright 2024 Anemoi contributors.
#
# This software is licensed under the terms of the Apache Licence Version 2.0
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
#
# In applying this licence, ECMWF does not waive the privileges and immunities
# granted to it by virtue of its status as an intergovernmental organisation
# nor does it submit to any jurisdiction.

from __future__ import annotations

import logging
from abc import ABC
from abc import abstractmethod

import torch
from torch_geometric.data import HeteroData

LOGGER = logging.getLogger(__name__)


class PostProcessor(ABC):

@abstractmethod
def update_graph(self, graph: HeteroData) -> HeteroData:
raise NotImplementedError(f"The {self.__class__.__name__} class does not implement the method update_graph().")


class RemoveUnconnectedNodes(PostProcessor):
"""Remove unconnected nodes in the graph.

Attributes
----------
nodes_name: str
Name of the unconnected nodes to remove.
ignore: str, optional
Name of an attribute to ignore when removing nodes. Nodes with
this attribute set to True will not be removed.
save_mask_indices_to_attr: str, optional
Name of the attribute to save the mask indices. If provided,
the indices of the kept nodes will be saved in this attribute.

Methods
-------
compute_mask(graph)
Compute the mask of the connected nodes.
prune_graph(graph, mask)
Prune the nodes with the specified mask.
add_attribute(graph, mask)
Add an attribute of the mask indices as node attribute.
update_graph(graph)
Post-process the graph.
"""

def __init__(
self,
nodes_name: str,
ignore: str | None,
save_mask_indices_to_attr: str | None,
) -> None:
self.nodes_name = nodes_name
self.ignore = ignore
self.save_mask_indices_to_attr = save_mask_indices_to_attr

def compute_mask(self, graph: HeteroData) -> torch.Tensor:
JPXKQX marked this conversation as resolved.
Show resolved Hide resolved
"""Compute the mask of connected nodes."""
nodes = graph[self.nodes_name]
connected_mask = torch.zeros(nodes.num_nodes, dtype=torch.bool)

if self.ignore is not None:
LOGGER.info(f"The nodes with {self.ignore}=True will not be removed.")
connected_mask[nodes[self.ignore].bool().squeeze()] = True

for (source_name, _, target_name), edges in graph.edge_items():
if source_name == self.nodes_name:
connected_mask[edges.edge_index[0]] = True

if target_name == self.nodes_name:
connected_mask[edges.edge_index[1]] = True

return connected_mask

def removing_nodes(self, graph: HeteroData, mask: torch.Tensor) -> HeteroData:
"""Remove nodes based on the mask passed."""
for attr_name in graph[self.nodes_name].node_attrs():
graph[self.nodes_name][attr_name] = graph[self.nodes_name][attr_name][mask]

return graph

def update_edge_indices(self, graph: HeteroData, idx_mapping: dict[int, int]) -> HeteroData:
"""Update the edge indices to the new position of the nodes."""
for edges_name in graph.edge_types:
if edges_name[0] == self.nodes_name:
graph[edges_name].edge_index[0] = graph[edges_name].edge_index[0].apply_(idx_mapping.get)

if edges_name[2] == self.nodes_name:
graph[edges_name].edge_index[1] = graph[edges_name].edge_index[1].apply_(idx_mapping.get)

return graph

def prune_graph(self, graph: HeteroData, mask: torch.Tensor) -> HeteroData:
"""Prune the nodes with the specified mask."""
LOGGER.info(f"Removing {(~mask).sum()} nodes from {self.nodes_name}.")

# Pruning nodes
graph = self.removing_nodes(graph, mask)

# Updating edge indices
idx_mapping = dict(zip(torch.where(mask)[0].tolist(), list(range(mask.sum()))))
graph = self.update_edge_indices(graph, idx_mapping)

return graph

def add_attribute(self, graph: HeteroData, mask: torch.Tensor) -> HeteroData:
"""Add an attribute of the mask indices as node attribute."""
if self.save_mask_indices_to_attr is not None:
LOGGER.info(
f"An attribute {self.save_mask_indices_to_attr} has been added with the indices to mask the nodes from the original graph."
)
graph[self.nodes_name][self.save_mask_indices_to_attr] = torch.where(mask)[0]

return graph

def update_graph(self, graph: HeteroData) -> HeteroData:
"""Post-process the graph.

Parameters
----------
graph: HeteroData
The graph to post-process.

Returns
-------
HeteroData
The post-processed graph.
"""
connected_mask = self.compute_mask(graph)
graph = self.prune_graph(graph, connected_mask)
graph = self.add_attribute(graph, connected_mask)
return graph
9 changes: 9 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,15 @@ def graph_with_nodes() -> HeteroData:
return graph


@pytest.fixture
def graph_with_isolated_nodes() -> HeteroData:
graph = HeteroData()
graph["test_nodes"].x = torch.tensor([[1], [2], [3], [4], [5], [6]])
graph["test_nodes"]["mask_attr"] = torch.tensor([[1], [1], [1], [0], [0], [0]], dtype=torch.bool)
graph["test_nodes", "to", "test_nodes"].edge_index = torch.tensor([[2, 3, 4], [1, 2, 3]])
return graph


@pytest.fixture
def graph_nodes_and_edges() -> HeteroData:
"""Graph with 1 set of nodes and edges."""
Expand Down
78 changes: 78 additions & 0 deletions tests/processors/test_post_process.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
# (C) Copyright 2024 Anemoi contributors.
#
# This software is licensed under the terms of the Apache Licence Version 2.0
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
#
# In applying this licence, ECMWF does not waive the privileges and immunities
# granted to it by virtue of its status as an intergovernmental organisation
# nor does it submit to any jurisdiction.
from __future__ import annotations

import pytest
import torch
from torch_geometric.data import HeteroData

from anemoi.graphs.processors.post_process import RemoveUnconnectedNodes


def test_remove_unconnected_nodes(graph_with_isolated_nodes: HeteroData):
processor = RemoveUnconnectedNodes(nodes_name="test_nodes", ignore=None, save_mask_indices_to_attr=None)

graph = processor.update_graph(graph_with_isolated_nodes)

assert graph["test_nodes"].num_nodes == 4
assert torch.equal(graph["test_nodes"].x, torch.tensor([[2], [3], [4], [5]]))
assert "original_indices" not in graph["test_nodes"]


def test_remove_unconnected_nodes_with_indices_attr(graph_with_isolated_nodes: HeteroData):
processor = RemoveUnconnectedNodes(
nodes_name="test_nodes", ignore=None, save_mask_indices_to_attr="original_indices"
)

graph = processor.update_graph(graph_with_isolated_nodes)

assert graph["test_nodes"].num_nodes == 4
assert torch.equal(graph["test_nodes"].x, torch.tensor([[2], [3], [4], [5]]))
assert torch.equal(graph["test_nodes", "to", "test_nodes"].edge_index, torch.tensor([[1, 2, 3], [0, 1, 2]]))
assert torch.equal(graph["test_nodes"].original_indices, torch.tensor([1, 2, 3, 4]))


def test_remove_unconnected_nodes_with_ignore(graph_with_isolated_nodes: HeteroData):
processor = RemoveUnconnectedNodes(nodes_name="test_nodes", ignore="mask_attr", save_mask_indices_to_attr=None)

graph = processor.update_graph(graph_with_isolated_nodes)

assert graph["test_nodes"].num_nodes == 5
assert torch.equal(graph["test_nodes"].x, torch.tensor([[1], [2], [3], [4], [5]]))
assert torch.equal(graph["test_nodes", "to", "test_nodes"].edge_index, torch.tensor([[2, 3, 4], [1, 2, 3]]))


@pytest.mark.parametrize(
"nodes_name,ignore,save_mask_indices_to_attr",
[
("test_nodes", None, "original_indices"),
("test_nodes", "mask_attr", None),
("test_nodes", None, None),
],
)
def test_remove_unconnected_nodes_parametrized(
graph_with_isolated_nodes: HeteroData,
nodes_name: str,
ignore: str | None,
save_mask_indices_to_attr: str | None,
):
processor = RemoveUnconnectedNodes(
nodes_name=nodes_name, ignore=ignore, save_mask_indices_to_attr=save_mask_indices_to_attr
)

graph = processor.update_graph(graph_with_isolated_nodes)

assert isinstance(graph, HeteroData)
pruned_nodes = 4 if ignore is None else 5
assert graph[nodes_name].num_nodes == pruned_nodes

if save_mask_indices_to_attr:
assert save_mask_indices_to_attr in graph[nodes_name]
else:
assert graph[nodes_name].node_attrs() == graph_with_isolated_nodes[nodes_name].node_attrs()
Loading