Skip to content

Commit

Permalink
D2-equivariant GATv2.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 603380204
  • Loading branch information
Petar Veličković authored and copybara-github committed Feb 1, 2024
1 parent 8df5b7a commit 4fe98f9
Showing 1 changed file with 58 additions and 0 deletions.
58 changes: 58 additions & 0 deletions clrs/_src/processors.py
Original file line number Diff line number Diff line change
Expand Up @@ -289,6 +289,64 @@ def __call__( # pytype: disable=signature-mismatch # numpy-scalars
return ret, None # pytype: disable=bad-return-type # numpy-scalars


class GATv2FullD2(GATv2):
"""Graph Attention Network v2 with full adjacency matrix and D2 symmetry."""

def d2_forward(self,
node_fts: List[_Array],
edge_fts: List[_Array],
graph_fts: List[_Array],
adj_mat: _Array,
hidden: _Array,
**unused_kwargs) -> List[_Array]:
num_d2_actions = 4

d2_inverses = [
0, 1, 2, 3 # All members of D_2 are self-inverses!
]

d2_multiply = [
[0, 1, 2, 3],
[1, 0, 3, 2],
[2, 3, 0, 1],
[3, 2, 1, 0],
]

assert len(node_fts) == num_d2_actions
assert len(edge_fts) == num_d2_actions
assert len(graph_fts) == num_d2_actions

ret_nodes = []
adj_mat = jnp.ones_like(adj_mat)

for g in range(num_d2_actions):
emb_values = []
for h in range(num_d2_actions):
gh = d2_multiply[d2_inverses[g]][h]
node_features = jnp.concatenate(
(node_fts[g], node_fts[gh]),
axis=-1)
edge_features = jnp.concatenate(
(edge_fts[g], edge_fts[gh]),
axis=-1)
graph_features = jnp.concatenate(
(graph_fts[g], graph_fts[gh]),
axis=-1)
cell_embedding = super().__call__(
node_fts=node_features,
edge_fts=edge_features,
graph_fts=graph_features,
adj_mat=adj_mat,
hidden=hidden
)
emb_values.append(cell_embedding)
ret_nodes.append(
jnp.mean(jnp.stack(emb_values, axis=0), axis=0)
)

return ret_nodes


class GATv2Full(GATv2):
"""Graph Attention Network v2 with full adjacency matrix."""

Expand Down

0 comments on commit 4fe98f9

Please sign in to comment.