Skip to content

Commit

Permalink
harmonise with RVGP
Browse files Browse the repository at this point in the history
  • Loading branch information
agosztolai committed Dec 1, 2023
1 parent 3e5001e commit 91eba6e
Show file tree
Hide file tree
Showing 10 changed files with 170 additions and 284 deletions.
2 changes: 1 addition & 1 deletion MARBLE/default_params.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ momentum: 0.9
#manifold/signal parameters
order: 2 # order to which to compute the directional derivatives
inner_product_features: False
diffusion: False
diffusion: True
frac_sampled_nb: -1 # fraction of neighbours to sample for gradient computation (if -1 then all neighbours)
include_positions: False # include positions as features (warning: this is untested!)
include_self: True # include vector at the center of feature
Expand Down
205 changes: 39 additions & 166 deletions MARBLE/geometry.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
from MARBLE import utils # isort:skip


def furthest_point_sampling(x, N=None, stop_crit=0.0, start_idx=0):
def furthest_point_sampling(x, N=None, spacing=0.0, start_idx=0):
"""A greedy O(N^2) algorithm to do furthest points sampling
Args:
Expand All @@ -36,7 +36,7 @@ def furthest_point_sampling(x, N=None, stop_crit=0.0, start_idx=0):
perm: node indices of the N sampled points
lambdas: list of distances of furthest points
"""
if stop_crit == 0.0:
if spacing == 0.0:
return torch.arange(len(x)), None

D = utils.np2torch(pairwise_distances(x))
Expand All @@ -54,7 +54,7 @@ def furthest_point_sampling(x, N=None, stop_crit=0.0, start_idx=0):
ds = torch.minimum(ds, D[idx, :])

if N is None:
if lambdas[i] / diam < stop_crit:
if lambdas[i] / diam < spacing:
perm = perm[:i]
lambdas = lambdas[:i]
break
Expand Down Expand Up @@ -314,10 +314,13 @@ def normalize_sparse_matrix(sp_tensor):
return sp_tensor


def map_to_local_gauges(x, gauges, length_correction=False):
def global_to_local_frame(x, gauges, length_correction=False, reverse=False):
"""Transform signal into local coordinates."""

proj = torch.einsum("aij,ai->aj", gauges, x)
if reverse:
proj = torch.einsum("bji,bi->bj", gauges, x)
else:
proj = torch.einsum("bij,bi->bj", gauges, x)

if length_correction:
norm_x = x.norm(p=2, dim=1, keepdim=True)
Expand Down Expand Up @@ -379,71 +382,6 @@ def fit_graph(x, graph_type="cknn", par=1, delta=1.0):
return edge_index, edge_weight


# def find_nn(X, ind_query=None, nn=1, r=None, theiler=10, n_jobs=-1):
# """
# Find nearest neighbors of a point on the manifold

# Parameters
# ----------
# ind_query : 2d np array, list[2d np array]
# Index of points whose neighbors are needed.
# x : nxd array (dimensions are columns!)
# Coordinates of n points on a manifold in d-dimensional space.
# nn : int, optional
# Number of nearest neighbors. The default is 1.
# theiler : int
# Theiler exclusion. Do not include the points immediately before or
# after in time the query point as neighbours.
# n_jobs : int, optional
# Number of processors to use. The default is all.

# Returns
# -------
# dist : list[list]
# Distance of nearest neighbors.
# ind : list[list]
# Index of nearest neighbors.

# """

# if ind_query is None:
# ind_query = np.arange(len(X))
# elif isinstance(ind_query, list):
# ind_query = np.vstack(ind_query)

# #Fit neighbor estimator object
# kdt = KDTree(X, leaf_size=30, metric='euclidean')

# inputs = [kdt, X, r, nn, theiler]
# ind = utils.parallel_proc(nb_query,
# ind_query,
# inputs,
# desc="Computing neighbours...")

# return ind


# def nb_query(inputs, i):

# kdt, X, r, nn, theiler = inputs

# x_query = X[i]
# if r is not None:
# ind, dist = kdt.query_radius(x_query, r=r, return_distance=True, sort_results=True)
# ind = ind[0]
# dist = dist[0]
# else:
# # apparently, the outputs are reversed here compared to query_radius()
# _, ind = kdt.query(x_query, k=nn+2*theiler+1)

# #Theiler exclusion (points immediately before or after are not useful neighbours)
# ind = ind[np.abs(ind-i)>theiler][:nn]

# # edges = np.vstack()

# return ind


def is_connected(edge_index):
"""Check if it is connected."""
adj = torch.sparse_coo_tensor(edge_index, torch.ones(edge_index.shape[1]))
Expand All @@ -461,7 +399,8 @@ def compute_laplacian(data, normalization="rw"):
num_nodes=data.num_nodes,
)

return PyGu.to_dense_adj(edge_index, edge_attr=edge_attr).squeeze()
# return PyGu.to_dense_adj(edge_index, edge_attr=edge_attr).squeeze()
return torch.sparse_coo_tensor(edge_index, edge_attr).coalesce()


def compute_connection_laplacian(data, R, normalization="rw"):
Expand All @@ -471,15 +410,11 @@ def compute_connection_laplacian(data, R, normalization="rw"):
data: Pytorch geometric data object.
R (nxnxdxd): Connection matrices between all pairs of nodes. Default is None,
in case of a global coordinate system.
normalization: None, 'sym', 'rw'
normalization: None, 'rw'
1. None: No normalization
:math:`\mathbf{L} = \mathbf{D} - \mathbf{A}`
2. "sym"`: Symmetric normalization
:math:`\mathbf{L} = \mathbf{I} - \mathbf{D}^{-1/2} \mathbf{A}
\mathbf{D}^{-1/2}`
3. "rw"`: Random-walk normalization
2. "rw": Random-walk normalization
:math:`\mathbf{L} = \mathbf{I} - \mathbf{D}^{-1} \mathbf{A}`
Returns:
Expand All @@ -489,7 +424,7 @@ def compute_connection_laplacian(data, R, normalization="rw"):
d = R.size()[0] // n

# unnormalised (combinatorial) laplacian, to be normalised later
L = compute_laplacian(data, normalization=None).to_sparse()
L = compute_laplacian(data, normalization=None)#.to_sparse()

# rearrange into block form (kron(L, ones(d,d)))
edge_index = utils.expand_edge_index(L.indices(), dim=d)
Expand All @@ -513,39 +448,26 @@ def compute_connection_laplacian(data, R, normalization="rw"):
deg_inv = deg_inv.repeat_interleave(d, dim=0)
Lc = torch.diag(deg_inv).to_sparse() @ Lc

elif normalization == "sym":
raise NotImplementedError

return Lc
return Lc.coalesce()


def compute_gauges(data, dim_man=None, n_geodesic_nb=10, n_workers=1):
r"""Orthonormal gauges for the tangent space at each node, and connection
matrices between each pair of adjacent nodes.
R is a block matrix, where the row index is the gauge we want to align to,
i.e., gauges(i) = R[i,j]@gauges(j).
R[i,j] is optimal rotation that minimises ||X - RY||_F computed by SVD:
X, Y = gauges[i].T, gauges[j].T
U, _, Vt = scipy.linalg.svd(X.T@Y)
R[i,j] = U@Vt
def compute_gauges(data, dim_man=None, n_geodesic_nb=10, processes=1):
"""Orthonormal gauges for the tangent space at each node.
Args:
data: Pytorch geometric data object.
n_geodesic_nb: number of geodesic neighbours. The default is 10.
processes: number of CPUs to use
Returns:
gauges (nxdimxdim matrix): Matrix containing dim unit vectors for each node.
Sigma: Singular values
R (n*dimxn*dim): Connection matrices.
"""
X = data.pos.numpy().astype(np.float64)
A = PyGu.to_scipy_sparse_matrix(data.edge_index).tocsr()

# make chunks for data processing
sl = data._slice_dict["x"] # pylint: disable=protected-access

n = len(sl) - 1
X = [X[sl[i] : sl[i + 1]] for i in range(n)]
A = [A[sl[i] : sl[i + 1], :][:, sl[i] : sl[i + 1]] for i in range(n)]
Expand All @@ -555,7 +477,8 @@ def compute_gauges(data, dim_man=None, n_geodesic_nb=10, n_workers=1):

inputs = [X, A, dim_man, n_geodesic_nb]
out = utils.parallel_proc(
_compute_gauges, range(n), inputs, processes=n_workers, desc="Computing tangent spaces..."
_compute_gauges, range(n), inputs, processes=processes,
desc="\n---- Computing tangent spaces..."
)

gauges, Sigma = zip(*out)
Expand All @@ -565,21 +488,27 @@ def compute_gauges(data, dim_man=None, n_geodesic_nb=10, n_workers=1):


def _compute_gauges(inputs, i):
    """Worker for compute_gauges(): fit the tangent frames of the i-th chunk.

    Args:
        inputs: sequence ``(X_chunks, A_chunks, dim_man, n_geodesic_nb)``;
            the first two entries are indexed by chunk.
        i: index of the chunk to process.

    Returns:
        The pair ``(gauges, Sigma)`` returned by tangent_frames() for chunk i.
    """
    points, adjacency, dim_man, n_geodesic_nb = inputs
    frames, singular_values = tangent_frames(
        points[i], adjacency[i], dim_man, n_geodesic_nb
    )
    return frames, singular_values


def compute_connections(data, gauges, n_workers=1):
r"""Find smallest rotations R between gauges pairs. It is assumed that the first
def compute_connections(data, gauges, processes=1):
"""Find smallest rotations R between pairs of gauges. It is assumed that the first
row of edge_index is what we want to align to, i.e.,
gauges(i) = gauges(j)@R[i,j].T
R[i,j] is optimal rotation that minimises ||X - RY||_F computed by SVD:
X, Y = gauges[i].T, gauges[j].T
U, _, Vt = scipy.linalg.svd(X.T@Y)
R[i,j] = U@Vt
Args:
data: Pytorch geometric data object
gauges (n,d,d matrix): Orthogonal unit vectors for each node
processes: number of CPUs to use
Returns:
(n*dim,n*dim) matrix of rotation matrices
Expand All @@ -597,13 +526,15 @@ def compute_connections(data, gauges, n_workers=1):

inputs = [gauges, A, dim_man]
out = utils.parallel_proc(
_compute_connections, range(n), inputs, processes=n_workers, desc="Computing connections..."
_compute_connections, range(n), inputs, processes=processes,
desc="\n---- Computing connections..."
)

return utils.to_block_diag(out)


def _compute_connections(inputs, i):
"""helper function to compute_connections()"""
gauges_chunks, A_chunks, dim_man = inputs

R = connections(gauges_chunks[i], A_chunks[i], dim_man)
Expand All @@ -614,66 +545,6 @@ def _compute_connections(inputs, i):
return torch.sparse_coo_tensor(edge_index, R.flatten(), dtype=torch.float32).coalesce()


def scalar_diffusion(x, t, method="matrix_exp", par=None):
    """Diffuse a scalar signal for time ``t``.

    Args:
        x: signal tensor; a 1-D tensor is treated as a single column.
        t: diffusion time. For the "spectral" method it must be a tensor,
            because ``t.unsqueeze`` is called on it.
        method: "matrix_exp" or "spectral".
        par: for "matrix_exp", the square operator (dense or sparse tensor);
            for "spectral", a pair ``(eigenvalues, eigenvectors)``.

    Returns:
        ``exp(-t * par) @ x`` for "matrix_exp", or the equivalent computed in
        the supplied eigenbasis for "spectral".

    Raises:
        ValueError: if ``par`` is None with the "matrix_exp" method.
        AssertionError: if ``par`` is not a pair with the "spectral" method.
        NotImplementedError: for any other ``method``.
    """
    if len(x.shape) == 1:
        x = x.unsqueeze(1)

    if method == "matrix_exp":
        if par is None:
            raise ValueError("For matrix_exp method, par must be a matrix!")
        # to_dense() returns a strided tensor for every layout, so a single
        # unconditional call replaces the former is_sparse check + second call.
        return torch.matrix_exp(-t * par.to_dense()).mm(x)

    if method == "spectral":
        assert isinstance(par, (list, tuple)) and len(par) == 2, (
            "For spectral method, par must be a tuple of "
            "eigenvalues, eigenvectors!"
        )
        evals, evecs = par

        # Transform to spectral
        x_spec = torch.mm(evecs.T, x)

        # Diffuse: each spectral coefficient decays as exp(-eigenvalue * t)
        diffusion_coefs = torch.exp(-evals.unsqueeze(-1) * t.unsqueeze(0))
        x_diffuse_spec = diffusion_coefs * x_spec

        # Transform back to per-vertex
        return evecs.mm(x_diffuse_spec)

    raise NotImplementedError


def vector_diffusion(x, t, method="spectral", Lc=None):
    """Diffuse a vector-valued signal using the connection Laplacian.

    The signal is flattened to ``nd`` rows, where ``nd`` is the size of the
    connection Laplacian, diffused with scalar_diffusion(), and reshaped back
    to the shape of ``x``.

    Args:
        x: signal tensor with at least two dimensions.
        t: diffusion time, forwarded to scalar_diffusion().
        method: diffusion method, forwarded to scalar_diffusion().
        Lc: for "spectral", a pair ``(eigenvalues, eigenvectors)`` of the
            connection Laplacian; otherwise the connection Laplacian itself.

    Returns:
        The diffused signal, with the same shape as ``x``.

    Raises:
        AssertionError: if ``Lc`` is not a pair for the "spectral" method, or
            if the number of entries ``n * d`` is not a multiple of ``nd``.
    """
    n, d = x.shape[0], x.shape[1]

    if method == "spectral":
        assert len(Lc) == 2, "Lc must be a tuple of eigenvalues, eigenvectors!"
        nd = Lc[0].shape[0]
    else:
        nd = Lc.shape[0]

    assert (n * d) % nd == 0, (
        "Data dimension must be an integer multiple of the dimensions "
        "of the connection Laplacian!"
    )

    # vector diffusion with connection Laplacian; reshape (rather than view)
    # so that non-contiguous inputs are accepted as well
    out = x.reshape(nd, -1)
    out = scalar_diffusion(out, t, method, Lc)

    # NOTE: a normalised variant (rescaling by the diffused magnitudes) was
    # sketched here previously but is not implemented.
    return out.reshape(x.shape)


def compute_eigendecomposition(A, k=None, eps=1e-8):
"""Eigendecomposition of a square matrix A.
Expand All @@ -688,11 +559,13 @@ def compute_eigendecomposition(A, k=None, eps=1e-8):
"""
if A is None:
return None

if k is not None and k >= A.shape[0]:
k = None

# Compute the eigenbasis

if k is None:
A = A.to_dense()
else:
indices, values, size = A.indices(), A.values(), A.size()
A = sp.coo_array((values, (indices[0], indices[1])), shape=size)

failcount = 0
while True:
try:
Expand All @@ -712,6 +585,6 @@ def compute_eigendecomposition(A, k=None, eps=1e-8):
raise ValueError("failed to compute eigendecomp") from e
failcount += 1
print("--- decomp failed; adding eps ===> count: " + str(failcount))
A += torch.eye(A.shape[0]) * (eps * 10 ** (failcount - 1))
A += sp.eye(A.shape[0]) * (eps * 10 ** (failcount - 1))

return evals, evecs
9 changes: 4 additions & 5 deletions MARBLE/layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from torch import nn
from torch_geometric.nn.conv import MessagePassing

from MARBLE import geometry as g
from MARBLE import smoothing as s


class Diffusion(nn.Module):
Expand All @@ -20,8 +20,7 @@ def forward(self, x, L, Lc=None, method="spectral"):
if method == "spectral":
assert (
len(L) == 2
), "L must be a matrix or a pair of eigenvalues \
and eigenvectors"
), "L must be a matrix or a pair of eigenvalues and eigenvectors"

# making sure diffusion times are positive
with torch.no_grad():
Expand All @@ -30,9 +29,9 @@ def forward(self, x, L, Lc=None, method="spectral"):
t = self.diffusion_time

if Lc is not None:
out = g.vector_diffusion(x, t, method, Lc)
out = s.vector_diffusion(x, t, Lc, L=L, method=method, normalise=True)
else:
out = [g.scalar_diffusion(x_, t, method, L) for x_ in x.T]
out = [s.scalar_diffusion(x_, t, method, L) for x_ in x.T]
out = torch.cat(out, axis=1)

return out
Expand Down
Loading

0 comments on commit 91eba6e

Please sign in to comment.