Skip to content

Commit

Permalink
feat: set mask as attribute
Browse files Browse the repository at this point in the history
  • Loading branch information
JPXKQX committed Nov 22, 2024
1 parent 09148ab commit c2f8f18
Showing 1 changed file with 14 additions and 13 deletions.
27 changes: 14 additions & 13 deletions src/anemoi/graphs/processors/post_process.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,20 +36,21 @@ def __init__(
) -> None:
self.nodes_name = nodes_name
self.save_mask_indices_to_attr = save_mask_indices_to_attr
self.mask: torch.Tensor = None

def removing_nodes(self, graph: HeteroData, mask: torch.Tensor) -> HeteroData:
def removing_nodes(self, graph: HeteroData) -> HeteroData:
"""Remove nodes based on the mask passed."""
for attr_name in graph[self.nodes_name].node_attrs():
graph[self.nodes_name][attr_name] = graph[self.nodes_name][attr_name][mask]
graph[self.nodes_name][attr_name] = graph[self.nodes_name][attr_name][self.mask]

return graph

def create_indices_mapper_from_mask(self, mask: torch.Tensor) -> dict[int, int]:
return dict(zip(torch.where(mask)[0].tolist(), list(range(mask.sum()))))
def create_indices_mapper_from_mask(self) -> dict[int, int]:
return dict(zip(torch.where(self.mask)[0].tolist(), list(range(self.mask.sum()))))

def update_edge_indices(self, graph: HeteroData, mask: torch.Tensor) -> HeteroData:
def update_edge_indices(self, graph: HeteroData) -> HeteroData:
"""Update the edge indices to the new position of the nodes."""
idx_mapping = self.create_indices_mapper_from_mask(mask)
idx_mapping = self.create_indices_mapper_from_mask()
for edges_name in graph.edge_types:
if edges_name[0] == self.nodes_name:
graph[edges_name].edge_index[0] = graph[edges_name].edge_index[0].apply_(idx_mapping.get)
Expand All @@ -62,13 +63,13 @@ def update_edge_indices(self, graph: HeteroData, mask: torch.Tensor) -> HeteroDa
@abstractmethod
def compute_mask(self, graph: HeteroData) -> torch.Tensor: ...

def add_attribute(self, graph: HeteroData, mask: torch.Tensor) -> HeteroData:
def add_attribute(self, graph: HeteroData) -> HeteroData:
"""Add an attribute of the mask indices as node attribute."""
if self.save_mask_indices_to_attr is not None:
LOGGER.info(
f"An attribute {self.save_mask_indices_to_attr} has been added with the indices to mask the nodes from the original graph."
)
mask_indices = torch.where(mask)[0].reshape((graph[self.nodes_name].num_nodes, -1))
mask_indices = torch.where(self.mask)[0].reshape((graph[self.nodes_name].num_nodes, -1))
graph[self.nodes_name][self.save_mask_indices_to_attr] = mask_indices

return graph
Expand All @@ -86,11 +87,11 @@ def update_graph(self, graph: HeteroData) -> HeteroData:
HeteroData
The post-processed graph.
"""
valid_mask = self.compute_mask(graph)
LOGGER.info(f"Removing {(~valid_mask).sum()} nodes from {self.nodes_name}.")
graph = self.removing_nodes(graph, valid_mask)
graph = self.update_edge_indices(graph, valid_mask)
graph = self.add_attribute(graph, valid_mask)
self.mask = self.compute_mask(graph)
LOGGER.info(f"Removing {(~self.mask).sum()} nodes from {self.nodes_name}.")
graph = self.removing_nodes(graph)
graph = self.update_edge_indices(graph)
graph = self.add_attribute(graph)
return graph


Expand Down

0 comments on commit c2f8f18

Please sign in to comment.