Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

PUBLIC: Add predecessor_pointers_to_permutation_matrix and permutation_matrix_to_predecessor_pointers methods to probing. #129

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
90 changes: 89 additions & 1 deletion clrs/_src/probing.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@
# Module-private shorthands for names defined in `specs`.
_Type = specs.Type
_OutputClass = specs.OutputClass

# Array alias covering both NumPy arrays and JAX arrays. Spelled with
# `typing.Union` (like the aliases below) rather than the `|` operator, which
# is only supported on classes at runtime from Python 3.10 onwards.
_Array = Union[np.ndarray, jax.Array]
# A single array or a list of arrays.
_Data = Union[_Array, List[_Array]]
# Either data (as above) or a string.
_DataOrType = Union[_Data, str]

Expand Down Expand Up @@ -312,6 +312,94 @@ def strings_pred(T_pos: np.ndarray, P_pos: np.ndarray) -> np.ndarray:
return probe


@functools.partial(jnp.vectorize, signature='(n)->(n,n)')
def predecessor_pointers_to_permutation_matrix(
    pointers: jnp.ndarray) -> jnp.ndarray:
  """Converts predecessor pointers to a permutation matrix.

  This function assumes that the pointers represent a linear order of the nodes
  (akin to a linked list), where each node points to its predecessor and the
  first node points to itself. It returns a permutation matrix `P` that sorts
  the nodes into the order implied by the pointers.

  The function is vectorized over leading (batch) dimensions.

  Example:
  ```
  pointers = [2, 1, 1]
  P = [[0, 1, 0],
       [0, 0, 1],
       [1, 0, 0]]
  ```

  Args:
    pointers: array of shape [N] containing pointers. The pointers are assumed
      to describe a linear order such that `pointers[i]` is the predecessor
      of node `i`. Non-integer dtypes (e.g. float-encoded pointers) are cast
      to `int32` before being used as indices.

  Returns:
    Permutation matrix `P` of shape [N, N]. Given node features `x` of shape
    [N, F], `P @ x` returns sorted node features.
  """
  nb_nodes = pointers.shape[-1]

  # Pointer values are used as array indices below, which requires an integer
  # dtype. Cast float-typed pointers instead of failing on the first lookup.
  if not jnp.issubdtype(pointers.dtype, jnp.integer):
    pointers = pointers.astype(jnp.int32)

  # Find the index of the last node: it's the node that no other node points
  # to, i.e. the column of the one-hot encoded pointers with the smallest sum.
  node = jax.nn.one_hot(pointers, nb_nodes).sum(-2).argmin()

  # Walk the chain of predecessors from the last node back to the first one.
  reversed_order = []
  for _ in range(nb_nodes):
    reversed_order.append(node)
    node = pointers[node]

  # Row i of the permutation matrix is the one-hot encoding of the node that
  # sits at position i of the sorted order.
  return jax.nn.one_hot(jnp.stack(reversed_order[::-1]), nb_nodes)


@functools.partial(jnp.vectorize, signature='(n,n)->(n)')
def permutation_matrix_to_predecessor_pointers(
    perm: jnp.ndarray) -> jnp.ndarray:
  """Converts a permutation matrix to predecessor pointers.

  Given an [N, N] permutation matrix `P` that sorts a list of nodes, this
  function returns predecessor pointers that encode the sorted order. The
  first node in the sorted order points to itself.

  The function is vectorized over leading (batch) dimensions.

  Example:
  ```
  P = [[0, 1, 0],
       [0, 0, 1],
       [1, 0, 0]]
  pointers = [2, 1, 1]
  ```

  Args:
    perm: permutation matrix of shape [N, N].

  Returns:
    An integer array of shape [N] containing predecessor pointers, so the
    result can be used directly as indices.
  """
  nb_nodes = perm.shape[-1]

  # order[i] is the index of the node at position i in the sorted order.
  order = perm.argmax(-1)

  # predecessors[i] is the predecessor of the node at position i: the first
  # node points to itself, every other node to the node one position earlier.
  predecessors = jnp.concatenate([order[:1], order[:-1]])

  # pointers[order[i]] = predecessors[i]. A scatter-add is used so that, for
  # badly formatted input where several rows select the same node, the
  # contributions are summed (and clamped below) instead of being dropped.
  pointers = jnp.zeros([nb_nodes], dtype=order.dtype).at[order].add(
      predecessors)

  # Ensure that the pointers are in the valid range even if the input is badly
  # formatted. This has no effect for well-formatted input.
  return jnp.minimum(pointers, nb_nodes - 1)


@functools.partial(jnp.vectorize, signature='(n)->(n,n),(n)')
def predecessor_to_cyclic_predecessor_and_first(
pointers: jnp.ndarray) -> Tuple[jnp.ndarray, jnp.ndarray]:
Expand Down