diff --git a/clrs/_src/processors.py b/clrs/_src/processors.py index 2dda2a7c..a28414b0 100644 --- a/clrs/_src/processors.py +++ b/clrs/_src/processors.py @@ -289,6 +289,64 @@ def __call__( # pytype: disable=signature-mismatch # numpy-scalars return ret, None # pytype: disable=bad-return-type # numpy-scalars +class GATv2FullD2(GATv2): + """Graph Attention Network v2 with full adjacency matrix and D2 symmetry.""" + + def d2_forward(self, + node_fts: List[_Array], + edge_fts: List[_Array], + graph_fts: List[_Array], + adj_mat: _Array, + hidden: _Array, + **unused_kwargs) -> List[_Array]: + num_d2_actions = 4 + + d2_inverses = [ + 0, 1, 2, 3 # All members of D_2 are self-inverses! + ] + + d2_multiply = [ + [0, 1, 2, 3], + [1, 0, 3, 2], + [2, 3, 0, 1], + [3, 2, 1, 0], + ] + + assert len(node_fts) == num_d2_actions + assert len(edge_fts) == num_d2_actions + assert len(graph_fts) == num_d2_actions + + ret_nodes = [] + adj_mat = jnp.ones_like(adj_mat) + + for g in range(num_d2_actions): + emb_values = [] + for h in range(num_d2_actions): + gh = d2_multiply[d2_inverses[g]][h] + node_features = jnp.concatenate( + (node_fts[g], node_fts[gh]), + axis=-1) + edge_features = jnp.concatenate( + (edge_fts[g], edge_fts[gh]), + axis=-1) + graph_features = jnp.concatenate( + (graph_fts[g], graph_fts[gh]), + axis=-1) + cell_embedding = super().__call__( + node_fts=node_features, + edge_fts=edge_features, + graph_fts=graph_features, + adj_mat=adj_mat, + hidden=hidden + ) + emb_values.append(cell_embedding) + ret_nodes.append( + jnp.mean(jnp.stack(emb_values, axis=0), axis=0) + ) + + return ret_nodes + + class GATv2Full(GATv2): """Graph Attention Network v2 with full adjacency matrix."""