# Reconstructed from the patch to clrs/_src/probing.py (index 2bcef02b..d9e9ad4a,
# hunk @@ -312,6 +312,94 @@). The first hunk of the same patch
# (@@ -39,7 +39,7 @@) widens the module alias from `_Array = np.ndarray` to
# `_Array = np.ndarray | jax.Array`.


@functools.partial(jnp.vectorize, signature='(n)->(n,n)')
def predecessor_pointers_to_permutation_matrix(
    pointers: jnp.ndarray) -> jnp.ndarray:
  """Converts predecessor pointers to a permutation matrix.

  This function assumes that the pointers represent a linear order of the nodes
  (akin to a linked list), where each node points to its predecessor and the
  first node points to itself. It returns a permutation matrix `P` that sorts
  the nodes into the order implied by the pointers.

  Because of the `jnp.vectorize` wrapper, any leading batch dimensions of
  `pointers` are mapped over.

  Example:
  ```
  pointers = [2, 1, 1]
  P = [[0, 1, 0],
       [0, 0, 1],
       [1, 0, 0]]
  ```

  Args:
    pointers: array of shape [N] containing pointers. The pointers are assumed
      to describe a linear order such that `pointers[i]` is the predecessor
      of node `i`.

  Returns:
    Permutation matrix `P` of shape [N, N] (in the default floating-point dtype
    of `jnp.zeros`). Given node features `x` of shape [N, F], `P @ x` returns
    sorted node features.
  """
  # Find the index of the last node: it's the node that no other node points to.
  # Summing the one-hot rows counts, for every node, how many nodes point to it.
  nb_nodes = pointers.shape[-1]
  pointers_one_hot = jax.nn.one_hot(pointers, nb_nodes)
  last = pointers_one_hot.sum(-2).argmin()

  # Initialize permutation matrix with zeros.
  perm = jnp.zeros([nb_nodes, nb_nodes])

  # Walk the list backwards from the last node, filling rows from the bottom
  # up. The update is written as an outer product of one-hot vectors rather
  # than an indexed assignment because `last` is an array value, not a Python
  # integer.
  for i in range(nb_nodes - 1, -1, -1):
    # perm[i, last] = 1
    perm += (
        jax.nn.one_hot(i, nb_nodes)[..., None] * jax.nn.one_hot(last, nb_nodes))
    last = pointers[last]

  return perm


@functools.partial(jnp.vectorize, signature='(n,n)->(n)')
def permutation_matrix_to_predecessor_pointers(
    perm: jnp.ndarray) -> jnp.ndarray:
  """Converts a permutation matrix to predecessor pointers.

  Given an [N, N] permutation matrix `P` that sorts a list of nodes, this
  function returns predecessor pointers that encode the sorted order.

  Because of the `jnp.vectorize` wrapper, any leading batch dimensions of
  `perm` are mapped over.

  Example:
  ```
  P = [[0, 1, 0],
       [0, 0, 1],
       [1, 0, 0]]
  pointers = [2, 1, 1]
  ```

  Args:
    perm: permutation matrix of shape [N, N].

  Returns:
    An integer array of shape [N] containing predecessor pointers. The node in
    the first sorted position points to itself.
  """
  nb_nodes = perm.shape[-1]

  # Initialize pointers to zeros.
  pointers = jnp.zeros([nb_nodes], dtype=int)

  # idx[i] is the index of the node at position i in the sorted order
  idx = perm.argmax(-1)

  # The one-hot masks are built in the integer dtype of `pointers`.
  # `jax.nn.one_hot` defaults to a floating-point dtype, which would promote
  # the accumulated pointers to floats and make them unusable as indices.
  # pointers[idx[0]] = idx[0]
  pointers += idx[0] * jax.nn.one_hot(idx[0], nb_nodes, dtype=pointers.dtype)

  for i in range(1, nb_nodes):
    # pointers[idx[i]] = idx[i - 1]
    pointers += idx[i - 1] * jax.nn.one_hot(
        idx[i], nb_nodes, dtype=pointers.dtype)

  # Ensure that the pointers are in the valid range even if the input is badly
  # formatted. This has no effect for well-formatted input.
  pointers = jnp.minimum(pointers, nb_nodes - 1)

  return pointers


# NOTE: the patch context continues with the pre-existing function
# `predecessor_to_cyclic_predecessor_and_first` (decorated with
# `jnp.vectorize`, signature '(n)->(n,n),(n)'); only its header is visible in
# this hunk, so it is not reproduced here.