Skip to content

Commit

Permalink
feat: support post-processors
Browse files Browse the repository at this point in the history
  • Loading branch information
JPXKQX committed Oct 25, 2024
1 parent d48ac7b commit c3c2a4c
Show file tree
Hide file tree
Showing 3 changed files with 124 additions and 0 deletions.
27 changes: 27 additions & 0 deletions src/anemoi/graphs/create.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,32 @@ def clean(self, graph: HeteroData) -> HeteroData:

return graph

def post_process(self, graph: HeteroData) -> HeteroData:
"""Allow post-processing of the resulting graph.
This method applies any configured post-processors to the graph,
which can modify or enhance the graph structure or attributes.
Parameters
----------
graph : HeteroData
The graph to be post-processed.
Returns
-------
HeteroData
The post-processed graph.
Notes
-----
Post-processors are applied in the order they are specified in the configuration.
Each post-processor should implement an `update_graph` method that takes and returns a HeteroData object.
"""
for processor in self.config.get("post_processors", {}):
graph = instantiate(processor).update_graph(graph)

return graph

def save(self, graph: HeteroData, save_path: Path, overwrite: bool = False) -> None:
"""Save the generated graph to the output path.
Expand Down Expand Up @@ -125,6 +151,7 @@ def create(self, save_path: Path | None = None, overwrite: bool = False) -> Hete
"""
graph = HeteroData()
graph = self.update_graph(graph)
graph = self.post_process(graph)
graph = self.clean(graph)

if save_path is None:
Expand Down
3 changes: 3 additions & 0 deletions src/anemoi/graphs/processors/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
from .post_process import RemoveUnconnectedNodes

__all__ = [RemoveUnconnectedNodes]
94 changes: 94 additions & 0 deletions src/anemoi/graphs/processors/post_process.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
# (C) Copyright 2024 Anemoi contributors.
#
# This software is licensed under the terms of the Apache Licence Version 2.0
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
#
# In applying this licence, ECMWF does not waive the privileges and immunities
# granted to it by virtue of its status as an intergovernmental organisation
# nor does it submit to any jurisdiction.

from __future__ import annotations

import logging
from abc import ABC

import torch
from torch_geometric.data import HeteroData

LOGGER = logging.getLogger(__name__)


class PostProcessor(ABC):
def update_graph(self, graph: HeteroData) -> HeteroData:
raise NotImplementedError(f"The {self.__class__.__name__} class does not implement the method update_graph().")


class RemoveUnconnectedNodes(PostProcessor):
"""Remove unconnected nodes in the graph."""

def __init__(
self,
nodes_name: str,
ignore: str | None,
save_mask_indices_to_attr: str | None,
) -> None:
self.nodes_name = nodes_name
self.ignore = ignore
self.save_mask_indices_to_attr = save_mask_indices_to_attr

def compute_mask(self, graph: HeteroData) -> torch.Tensor:
nodes = graph[self.nodes_name]
connected_mask = torch.zeros(nodes.num_nodes, dtype=torch.bool)

for (source_name, _, target_name), edges in graph.edge_items():
if source_name == self.nodes_name:
connected_mask[edges.edge_index[0]] = True

if target_name == self.nodes_name:
connected_mask[edges.edge_index[1]] = True

return connected_mask

def removing_nodes(self, graph: HeteroData, mask: torch.Tensor) -> HeteroData:
for attr_name in graph[self.nodes_name].node_attrs():
graph[self.nodes_name][attr_name] = graph[self.nodes_name][attr_name][mask]

return graph

def update_edge_indices(self, graph: HeteroData, idx_mapping: dict[int, int]) -> HeteroData:
for edges_name in graph.edge_types:
if edges_name[0] == self.nodes_name:
graph[edges_name].edge_index[0] = graph[edges_name].edge_index[0].apply_(idx_mapping.get)

if edges_name[2] == self.nodes_name:
graph[edges_name].edge_index[1] = graph[edges_name].edge_index[1].apply_(idx_mapping.get)

return graph

def prune_graph(self, graph: HeteroData, mask: torch.Tensor) -> HeteroData:
LOGGER.info(f"Removing {(~mask).sum()} nodes from {self.nodes_name}.")

# Pruning nodes
graph = self.removing_nodes(graph, mask)

# Updating edge indices
idx_mapping = dict(zip(torch.where(mask)[0].tolist(), list(range(mask.sum()))))
graph = self.update_edge_indices(graph, idx_mapping)

return graph

def add_attribute(self, graph: HeteroData, mask: torch.Tensor) -> HeteroData:
if self.save_mask_indices_to_attr is not None:
LOGGER.info(
f"An attribute {self.save_mask_indices_to_attr} has been added with the indices to mask the nodes from the original graph."
)
graph[self.nodes_name][self.save_mask_indices_to_attr] = torch.where(mask)[0]

return graph

def update_graph(self, graph: HeteroData) -> HeteroData:
connected_mask = self.compute_mask(graph)
graph = self.prune_graph(graph, connected_mask)
graph = self.add_attribute(graph, connected_mask)

return graph

0 comments on commit c3c2a4c

Please sign in to comment.