Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[feature] support post-processors #71

Merged
merged 25 commits into from
Nov 22, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
25 commits
Select commit Hold shift + click to select a range
c3c2a4c
feat: support post-processors
JPXKQX Oct 25, 2024
a0e869b
update CHANGELOG.md
JPXKQX Oct 25, 2024
8a4014c
feat: ignore and docstrings
JPXKQX Oct 26, 2024
92c5b4f
fix: add tests
JPXKQX Oct 26, 2024
404be91
fix: minor updates
JPXKQX Oct 28, 2024
1a3635c
fix: import annotations
JPXKQX Oct 28, 2024
c742bf7
Merge branch 'develop' into feature/remove-unconnected-nodes
JPXKQX Oct 28, 2024
dd5f187
fix: docstring & abstractmethod
JPXKQX Oct 28, 2024
877404b
Merge branch 'feature/remove-unconnected-nodes' of github.com:ecmwf/a…
JPXKQX Oct 28, 2024
1a27371
fix: homogeneize docstrings
JPXKQX Oct 28, 2024
744607e
fix: change default
JPXKQX Nov 6, 2024
ce5879c
Merge branch 'develop' into feature/remove-unconnected-nodes
JPXKQX Nov 6, 2024
56a48cc
Merge branch 'develop' into feature/remove-unconnected-nodes
JPXKQX Nov 8, 2024
274859d
fix: changelog
JPXKQX Nov 8, 2024
b42f32c
Merge branch 'develop' into feature/remove-unconnected-nodes
JPXKQX Nov 11, 2024
63b9473
Merge branch 'develop' into feature/remove-unconnected-nodes
JPXKQX Nov 11, 2024
6305489
Merge branch 'develop' into feature/remove-unconnected-nodes
JPXKQX Nov 13, 2024
445b241
fix: changelog
JPXKQX Nov 13, 2024
49a60c7
Merge branch 'develop' into feature/remove-unconnected-nodes
JPXKQX Nov 21, 2024
56b4bff
fix: new attribute shape (num_nodes, 1)
JPXKQX Nov 21, 2024
99a144f
fix: tests
JPXKQX Nov 21, 2024
8ad6c76
fix: tests shape
JPXKQX Nov 21, 2024
09148ab
feat: provide base class for MaskBasedProcessor
JPXKQX Nov 21, 2024
c2f8f18
feat: set mask as attribute
JPXKQX Nov 22, 2024
8a2d88d
fix: clean hidden attributes before apply post processors
JPXKQX Nov 22, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,11 @@ Please add your functional changes to the appropriate section in the PR.
Keep it human-readable, your future self will thank you!

## [Unreleased](https://github.com/ecmwf/anemoi-graphs/compare/0.4.0...HEAD)

### Added
- feat: Define node sets and edges based on an ICON icosahedral mesh (#53)
- feat: Add support for `post_processors` in the recipe. (#71)
- feat: Add `RemoveUnconnectedNodes` post processor to clean unconnected nodes in LAM. (#71)

## [0.4.0 - LAM and stretched graphs](https://github.com/ecmwf/anemoi-graphs/compare/0.3.0...0.4.0) - 2024-11-08

Expand Down
27 changes: 27 additions & 0 deletions src/anemoi/graphs/create.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,32 @@ def clean(self, graph: HeteroData) -> HeteroData:

return graph

def post_process(self, graph: HeteroData) -> HeteroData:
    """Allow post-processing of the resulting graph.

    This method applies any configured post-processors to the graph,
    which can modify or enhance the graph structure or attributes.

    Parameters
    ----------
    graph : HeteroData
        The graph to be post-processed.

    Returns
    -------
    HeteroData
        The post-processed graph. It is returned unchanged when no
        post-processors are configured.

    Notes
    -----
    Post-processors are applied in the order they are specified in the configuration.
    Each post-processor should implement an `update_graph` method that takes and returns a HeteroData object.
    """
    # `or []` also covers a recipe with an explicit `post_processors: null`,
    # for which `.get(key, [])` would return None and break the loop below.
    processor_configs = self.config.get("post_processors") or []
    for processor_config in processor_configs:
        graph = instantiate(processor_config).update_graph(graph)

    return graph

def save(self, graph: HeteroData, save_path: Path, overwrite: bool = False) -> None:
"""Save the generated graph to the output path.

Expand Down Expand Up @@ -126,6 +152,7 @@ def create(self, save_path: Path | None = None, overwrite: bool = False) -> Hete
graph = HeteroData()
graph = self.update_graph(graph)
graph = self.clean(graph)
graph = self.post_process(graph)

if save_path is None:
LOGGER.warning("No output path specified. The graph will not be saved.")
Expand Down
3 changes: 3 additions & 0 deletions src/anemoi/graphs/processors/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
from .post_process import RemoveUnconnectedNodes

# `__all__` must contain the public names as strings; listing the class object
# itself makes `from anemoi.graphs.processors import *` fail with a TypeError.
__all__ = ["RemoveUnconnectedNodes"]
149 changes: 149 additions & 0 deletions src/anemoi/graphs/processors/post_process.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,149 @@
# (C) Copyright 2024 Anemoi contributors.
#
# This software is licensed under the terms of the Apache Licence Version 2.0
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
#
# In applying this licence, ECMWF does not waive the privileges and immunities
# granted to it by virtue of its status as an intergovernmental organisation
# nor does it submit to any jurisdiction.

from __future__ import annotations

import logging
from abc import ABC
from abc import abstractmethod

import torch
from torch_geometric.data import HeteroData

LOGGER = logging.getLogger(__name__)


class PostProcessor(ABC):
    """Abstract interface for graph post-processors.

    Concrete post-processors implement ``update_graph``, which receives a
    graph and returns the post-processed graph.
    """

    @abstractmethod
    def update_graph(self, graph: HeteroData) -> HeteroData:
        """Post-process the graph.

        Parameters
        ----------
        graph : HeteroData
            The graph to post-process.

        Returns
        -------
        HeteroData
            The post-processed graph.

        Raises
        ------
        NotImplementedError
            Always, when this base implementation is called directly.
        """
        class_name = type(self).__name__
        raise NotImplementedError(f"The {class_name} class does not implement the method update_graph().")


class BaseMaskingProcessor(PostProcessor, ABC):
    """Base class for mask based processor.

    Subclasses implement ``compute_mask``. ``update_graph`` then keeps only the
    nodes of ``nodes_name`` selected by the mask, re-indexes the edges that
    touch that node set and, optionally, stores the original indices of the
    kept nodes as a node attribute.

    Attributes
    ----------
    nodes_name : str
        Name of the node set the mask is applied to.
    save_mask_indices_to_attr : str, optional
        Name of the node attribute in which to store the original indices of
        the kept nodes. Nothing is stored when it is None.
    mask : torch.Tensor, optional
        Boolean mask over the nodes (True = keep). It is None until
        ``update_graph`` has been called.
    """

    def __init__(
        self,
        nodes_name: str,
        save_mask_indices_to_attr: str | None = None,
    ) -> None:
        self.nodes_name = nodes_name
        self.save_mask_indices_to_attr = save_mask_indices_to_attr
        # Populated by update_graph(); None until a mask has been computed.
        self.mask: torch.Tensor | None = None

    def removing_nodes(self, graph: HeteroData) -> HeteroData:
        """Remove nodes based on the mask passed.

        Every node attribute of ``nodes_name`` is indexed with ``self.mask``,
        so only the entries where the mask is True are kept.
        """
        nodes = graph[self.nodes_name]
        for attr_name in nodes.node_attrs():
            nodes[attr_name] = nodes[attr_name][self.mask]

        return graph

    def create_indices_mapper_from_mask(self) -> dict[int, int]:
        """Map the original index of each kept node to its index after removal."""
        kept_indices = torch.where(self.mask)[0].tolist()
        return {old_index: new_index for new_index, old_index in enumerate(kept_indices)}

    def _new_positions(self) -> torch.Tensor:
        """Build a lookup table from original node index to new index (-1 if removed)."""
        keep = self.mask.reshape(-1).bool()
        # The running count of kept nodes, minus one, is the new position of each kept node.
        positions = torch.cumsum(keep.long(), dim=0) - 1
        positions[~keep] = -1
        return positions

    @staticmethod
    def _remap_indices(indices: torch.Tensor, new_positions: torch.Tensor) -> torch.Tensor:
        """Translate original node indices into their positions after removal."""
        remapped = new_positions.to(indices.device)[indices]
        if bool((remapped < 0).any()):
            raise ValueError("Cannot update the edge indices: some edges reference removed nodes.")
        return remapped

    def update_edge_indices(self, graph: HeteroData) -> HeteroData:
        """Update the edge indices to the new position of the nodes.

        Only edge types whose source and/or target is ``nodes_name`` are
        touched. The remapping is a vectorised table lookup, so it does not
        depend on the tensors living on the CPU.

        Raises
        ------
        ValueError
            If an edge references a node that the mask removes.
        """
        new_positions = self._new_positions()
        for edges_name in graph.edge_types:
            source_name, _, target_name = edges_name
            for row, endpoint_name in ((0, source_name), (1, target_name)):
                if endpoint_name != self.nodes_name:
                    continue
                edge_index = graph[edges_name].edge_index
                # Row assignment keeps the update in place on the stored tensor.
                edge_index[row] = self._remap_indices(edge_index[row], new_positions)

        return graph

    @abstractmethod
    def compute_mask(self, graph: HeteroData) -> torch.Tensor:
        """Return the boolean mask of the nodes to keep (True = keep)."""
        ...

    def add_attribute(self, graph: HeteroData) -> HeteroData:
        """Add an attribute of the mask indices as node attribute.

        The attribute has shape (number of kept nodes, 1). Nothing is added
        when ``save_mask_indices_to_attr`` is None.
        """
        if self.save_mask_indices_to_attr is not None:
            LOGGER.info(
                "An attribute %s has been added with the indices to mask the nodes from the original graph.",
                self.save_mask_indices_to_attr,
            )
            # reshape(-1, 1) also works when no node is kept, where a
            # (0, -1) target shape cannot be resolved.
            mask_indices = torch.where(self.mask)[0].reshape(-1, 1)
            graph[self.nodes_name][self.save_mask_indices_to_attr] = mask_indices

        return graph

    def update_graph(self, graph: HeteroData) -> HeteroData:
        """Post-process the graph.

        Parameters
        ----------
        graph: HeteroData
            The graph to post-process.

        Returns
        -------
        HeteroData
            The post-processed graph.
        """
        self.mask = self.compute_mask(graph)
        LOGGER.info("Removing %s nodes from %s.", int((~self.mask).sum()), self.nodes_name)
        graph = self.removing_nodes(graph)
        graph = self.update_edge_indices(graph)
        graph = self.add_attribute(graph)
        return graph


class RemoveUnconnectedNodes(BaseMaskingProcessor):
    """Remove unconnected nodes in the graph.

    A node of ``nodes_name`` is kept when it appears as source or target of at
    least one edge, or when the ``ignore`` attribute is True for it.

    Attributes
    ----------
    nodes_name: str
        Name of the unconnected nodes to remove.
    ignore: str, optional
        Name of an attribute to ignore when removing nodes. Nodes with
        this attribute set to True will not be removed.
    save_mask_indices_to_attr: str, optional
        Name of the attribute to save the mask indices. If provided,
        the indices of the kept nodes will be saved in this attribute.

    Methods
    -------
    compute_mask(graph)
        Compute the mask of the connected nodes.
    removing_nodes(graph)
        Remove the nodes that are not selected by the mask.
    update_edge_indices(graph)
        Update the edge indices to the new position of the nodes.
    add_attribute(graph)
        Add an attribute of the mask indices as node attribute.
    update_graph(graph)
        Post-process the graph.
    """

    def __init__(
        self,
        nodes_name: str,
        save_mask_indices_to_attr: str | None = None,
        ignore: str | None = None,
    ) -> None:
        super().__init__(nodes_name, save_mask_indices_to_attr)
        self.ignore = ignore

    def compute_mask(self, graph: HeteroData) -> torch.Tensor:
        """Compute the mask of connected nodes.

        Parameters
        ----------
        graph : HeteroData
            The graph whose edges are inspected.

        Returns
        -------
        torch.Tensor
            One-dimensional boolean tensor with one entry per node of
            ``nodes_name``; True marks the nodes to keep.
        """
        nodes = graph[self.nodes_name]
        connected_mask = torch.zeros(nodes.num_nodes, dtype=torch.bool)

        if self.ignore is not None:
            LOGGER.info("The nodes with %s=True will not be removed.", self.ignore)
            # Flatten explicitly so the index stays one-dimensional.
            protected = nodes[self.ignore].bool().reshape(-1)
            connected_mask[protected] = True

        for (source_name, _, target_name), edges in graph.edge_items():
            if source_name == self.nodes_name:
                connected_mask[edges.edge_index[0]] = True

            if target_name == self.nodes_name:
                connected_mask[edges.edge_index[1]] = True

        return connected_mask
9 changes: 9 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,15 @@ def graph_with_nodes() -> HeteroData:
return graph


@pytest.fixture
def graph_with_isolated_nodes() -> HeteroData:
    """Graph with 6 nodes where the first and last node have no edges."""
    node_values = torch.tensor([[1], [2], [3], [4], [5], [6]])
    ignore_flags = torch.tensor([[1], [1], [1], [0], [0], [0]], dtype=torch.bool)
    self_edges = torch.tensor([[2, 3, 4], [1, 2, 3]])

    graph = HeteroData()
    graph["test_nodes"].x = node_values
    graph["test_nodes"]["mask_attr"] = ignore_flags
    graph["test_nodes", "to", "test_nodes"].edge_index = self_edges
    return graph


@pytest.fixture
def graph_nodes_and_edges() -> HeteroData:
"""Graph with 1 set of nodes and edges."""
Expand Down
79 changes: 79 additions & 0 deletions tests/processors/test_post_process.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
# (C) Copyright 2024 Anemoi contributors.
#
# This software is licensed under the terms of the Apache Licence Version 2.0
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
#
# In applying this licence, ECMWF does not waive the privileges and immunities
# granted to it by virtue of its status as an intergovernmental organisation
# nor does it submit to any jurisdiction.
from __future__ import annotations

import pytest
import torch
from torch_geometric.data import HeteroData

from anemoi.graphs.processors.post_process import RemoveUnconnectedNodes


def test_remove_unconnected_nodes(graph_with_isolated_nodes: HeteroData):
    """Nodes without any edge are dropped and no index attribute is stored."""
    pruner = RemoveUnconnectedNodes(nodes_name="test_nodes", ignore=None, save_mask_indices_to_attr=None)

    pruned = pruner.update_graph(graph_with_isolated_nodes)
    nodes = pruned["test_nodes"]

    expected_x = torch.tensor([[2], [3], [4], [5]])
    assert nodes.num_nodes == 4
    assert torch.equal(nodes.x, expected_x)
    assert "original_indices" not in nodes


def test_remove_unconnected_nodes_with_indices_attr(graph_with_isolated_nodes: HeteroData):
    """The kept nodes' original indices are stored and the edges are re-indexed."""
    pruner = RemoveUnconnectedNodes(
        nodes_name="test_nodes", ignore=None, save_mask_indices_to_attr="original_indices"
    )

    pruned = pruner.update_graph(graph_with_isolated_nodes)
    nodes = pruned["test_nodes"]
    edges = pruned["test_nodes", "to", "test_nodes"]

    expected_x = torch.tensor([[2], [3], [4], [5]])
    expected_edge_index = torch.tensor([[1, 2, 3], [0, 1, 2]])
    expected_indices = torch.tensor([[1], [2], [3], [4]])

    assert nodes.num_nodes == 4
    assert torch.equal(nodes.x, expected_x)
    assert torch.equal(edges.edge_index, expected_edge_index)
    assert torch.equal(nodes.original_indices, expected_indices)


def test_remove_unconnected_nodes_with_ignore(graph_with_isolated_nodes: HeteroData):
    """Nodes flagged by the ignore attribute are kept even without edges."""
    pruner = RemoveUnconnectedNodes(nodes_name="test_nodes", ignore="mask_attr", save_mask_indices_to_attr=None)

    pruned = pruner.update_graph(graph_with_isolated_nodes)
    nodes = pruned["test_nodes"]
    edges = pruned["test_nodes", "to", "test_nodes"]

    expected_x = torch.tensor([[1], [2], [3], [4], [5]])
    expected_edge_index = torch.tensor([[2, 3, 4], [1, 2, 3]])

    assert nodes.num_nodes == 5
    assert torch.equal(nodes.x, expected_x)
    assert torch.equal(edges.edge_index, expected_edge_index)


@pytest.mark.parametrize(
    "nodes_name,ignore,save_mask_indices_to_attr",
    [
        ("test_nodes", None, "original_indices"),
        ("test_nodes", "mask_attr", None),
        ("test_nodes", None, None),
    ],
)
def test_remove_unconnected_nodes_parametrized(
    graph_with_isolated_nodes: HeteroData,
    nodes_name: str,
    ignore: str | None,
    save_mask_indices_to_attr: str | None,
):
    """Check the node count and the resulting node attributes for each configuration."""
    # Snapshot the attribute names BEFORE processing. The processor mutates and
    # returns the graph it receives, so comparing the result against the
    # fixture afterwards would compare an object with itself.
    attrs_before = sorted(graph_with_isolated_nodes[nodes_name].node_attrs())

    processor = RemoveUnconnectedNodes(
        nodes_name=nodes_name, ignore=ignore, save_mask_indices_to_attr=save_mask_indices_to_attr
    )

    graph = processor.update_graph(graph_with_isolated_nodes)

    assert isinstance(graph, HeteroData)
    pruned_nodes = 4 if ignore is None else 5
    assert graph[nodes_name].num_nodes == pruned_nodes

    if save_mask_indices_to_attr:
        assert save_mask_indices_to_attr in graph[nodes_name]
        assert graph[nodes_name][save_mask_indices_to_attr].ndim == 2
    else:
        assert sorted(graph[nodes_name].node_attrs()) == attrs_before
Loading