diff --git a/MARBLE/default_params.yaml b/MARBLE/default_params.yaml
index 8e8b7be6..8b75d9bc 100644
--- a/MARBLE/default_params.yaml
+++ b/MARBLE/default_params.yaml
@@ -7,7 +7,7 @@ momentum: 0.9
 #manifold/signal parameters
 order: 2 # order to which to compute the directional derivatives
 inner_product_features: False
-diffusion: False
+diffusion: True
 frac_sampled_nb: -1 # fraction of neighbours to sample for gradient computation (if -1 then all neighbours)
 include_positions: False # include positions as features (warning: this is untested!)
 include_self: True # include vector at the center of feature
diff --git a/MARBLE/geometry.py b/MARBLE/geometry.py
index f30b6d6d..662bc66a 100644
--- a/MARBLE/geometry.py
+++ b/MARBLE/geometry.py
@@ -23,7 +23,7 @@ from MARBLE import utils  # isort:skip
 
 
-def furthest_point_sampling(x, N=None, stop_crit=0.0, start_idx=0):
-    """A greedy O(N^2) algorithm to do furthest points sampling
+def furthest_point_sampling(x, N=None, spacing=0.0, start_idx=0):
+    """A greedy O(N^2) algorithm for furthest point sampling.
 
     Args:
@@ -36,7 +36,7 @@ def furthest_point_sampling(x, N=None, spacing=0.0, start_idx=0):
         perm: node indices of the N sampled points
         lambdas: list of distances of furthest points
     """
-    if stop_crit == 0.0:
+    if spacing == 0.0:
         return torch.arange(len(x)), None
 
     D = utils.np2torch(pairwise_distances(x))
@@ -54,7 +54,7 @@ def furthest_point_sampling(x, N=None, spacing=0.0, start_idx=0):
         ds = torch.minimum(ds, D[idx, :])
 
         if N is None:
-            if lambdas[i] / diam < stop_crit:
+            if lambdas[i] / diam < spacing:
                 perm = perm[:i]
                 lambdas = lambdas[:i]
                 break
@@ -314,10 +314,13 @@ def normalize_sparse_matrix(sp_tensor):
     return sp_tensor
 
 
-def map_to_local_gauges(x, gauges, length_correction=False):
-    """Transform signal into local coordinates."""
+def global_to_local_frame(x, gauges, length_correction=False, reverse=False):
+    """Transform signal between global and local coordinates (reverse=True maps local to global)."""
 
-    proj = torch.einsum("aij,ai->aj", gauges, x)
+    if reverse:
+        proj = torch.einsum("bji,bi->bj", gauges, x)
+    else:
+        proj = torch.einsum("bij,bi->bj", gauges, x)
 
     if length_correction:
         norm_x = x.norm(p=2, dim=1, keepdim=True)
@@ -379,71 +382,6 @@ def fit_graph(x, graph_type="cknn", par=1, delta=1.0):
     return edge_index, edge_weight
 
 
-# def find_nn(X, ind_query=None, nn=1, r=None, theiler=10, n_jobs=-1):
-#     """
-#     Find nearest neighbors of a point on the manifold
-
-#     Parameters
-#     ----------
-#     ind_query : 2d np array, list[2d np array]
-#         Index of points whose neighbors are needed.
-#     x : nxd array (dimensions are columns!)
-#         Coordinates of n points on a manifold in d-dimensional space.
-#     nn : int, optional
-#         Number of nearest neighbors. The default is 1.
-#     theiler : int
-#         Theiler exclusion. Do not include the points immediately before or
-#         after in time the query point as neighbours.
-#     n_jobs : int, optional
-#         Number of processors to use. The default is all.
-
-#     Returns
-#     -------
-#     dist : list[list]
-#         Distance of nearest neighbors.
-#     ind : list[list]
-#         Index of nearest neighbors.
- -# """ - -# if ind_query is None: -# ind_query = np.arange(len(X)) -# elif isinstance(ind_query, list): -# ind_query = np.vstack(ind_query) - -# #Fit neighbor estimator object -# kdt = KDTree(X, leaf_size=30, metric='euclidean') - -# inputs = [kdt, X, r, nn, theiler] -# ind = utils.parallel_proc(nb_query, -# ind_query, -# inputs, -# desc="Computing neighbours...") - -# return ind - - -# def nb_query(inputs, i): - -# kdt, X, r, nn, theiler = inputs - -# x_query = X[i] -# if r is not None: -# ind, dist = kdt.query_radius(x_query, r=r, return_distance=True, sort_results=True) -# ind = ind[0] -# dist = dist[0] -# else: -# # apparently, the outputs are reversed here compared to query_radius() -# _, ind = kdt.query(x_query, k=nn+2*theiler+1) - -# #Theiler exclusion (points immediately before or after are not useful neighbours) -# ind = ind[np.abs(ind-i)>theiler][:nn] - -# # edges = np.vstack() - -# return ind - - def is_connected(edge_index): """Check if it is connected.""" adj = torch.sparse_coo_tensor(edge_index, torch.ones(edge_index.shape[1])) @@ -461,7 +399,8 @@ def compute_laplacian(data, normalization="rw"): num_nodes=data.num_nodes, ) - return PyGu.to_dense_adj(edge_index, edge_attr=edge_attr).squeeze() + # return PyGu.to_dense_adj(edge_index, edge_attr=edge_attr).squeeze() + return torch.sparse_coo_tensor(edge_index, edge_attr).coalesce() def compute_connection_laplacian(data, R, normalization="rw"): @@ -471,15 +410,11 @@ def compute_connection_laplacian(data, R, normalization="rw"): data: Pytorch geometric data object. R (nxnxdxd): Connection matrices between all pairs of nodes. Default is None, in case of a global coordinate system. - normalization: None, 'sym', 'rw' + normalization: None, 'rw' 1. None: No normalization :math:`\mathbf{L} = \mathbf{D} - \mathbf{A}` - 2. "sym"`: Symmetric normalization - :math:`\mathbf{L} = \mathbf{I} - \mathbf{D}^{-1/2} \mathbf{A} - \mathbf{D}^{-1/2}` - - 3. "rw"`: Random-walk normalization + 2. "rw"`: Random-walk normalization :math:`\mathbf{L} = \mathbf{I} - \mathbf{D}^{-1} \mathbf{A}` Returns: @@ -489,7 +424,7 @@ def compute_connection_laplacian(data, R, normalization="rw"): d = R.size()[0] // n # unnormalised (combinatorial) laplacian, to be normalised later - L = compute_laplacian(data, normalization=None).to_sparse() + L = compute_laplacian(data, normalization=None)#.to_sparse() # rearrange into block form (kron(L, ones(d,d))) edge_index = utils.expand_edge_index(L.indices(), dim=d) @@ -513,39 +448,26 @@ def compute_connection_laplacian(data, R, normalization="rw"): deg_inv = deg_inv.repeat_interleave(d, dim=0) Lc = torch.diag(deg_inv).to_sparse() @ Lc - elif normalization == "sym": - raise NotImplementedError - - return Lc + return Lc.coalesce() -def compute_gauges(data, dim_man=None, n_geodesic_nb=10, n_workers=1): - r"""Orthonormal gauges for the tangent space at each node, and connection - matrices between each pair of adjacent nodes. - - R is a block matrix, where the row index is the gauge we want to align to, - i.e., gauges(i) = R[i,j]@gauges(j). - - R[i,j] is optimal rotation that minimises ||X - RY||_F computed by SVD: - X, Y = gauges[i].T, gauges[j].T - U, _, Vt = scipy.linalg.svd(X.T@Y) - R[i,j] = U@Vt +def compute_gauges(data, dim_man=None, n_geodesic_nb=10, processes=1): + """Orthonormal gauges for the tangent space at each node. Args: data: Pytorch geometric data object. n_geodesic_nb: number of geodesic neighbours. The default is 10. 
+        processes: number of CPUs to use
 
     Returns:
         gauges (nxdimxdim matrix): Matrix containing dim unit vectors for each node.
-        Sigma: Singular valued
-        R (n*dimxn*dim): Connection matrices.
+        Sigma: singular values
     """
     X = data.pos.numpy().astype(np.float64)
     A = PyGu.to_scipy_sparse_matrix(data.edge_index).tocsr()
 
     # make chunks for data processing
     sl = data._slice_dict["x"]  # pylint: disable=protected-access
     n = len(sl) - 1
     X = [X[sl[i] : sl[i + 1]] for i in range(n)]
     A = [A[sl[i] : sl[i + 1], :][:, sl[i] : sl[i + 1]] for i in range(n)]
@@ -555,7 +477,8 @@
     inputs = [X, A, dim_man, n_geodesic_nb]
 
     out = utils.parallel_proc(
-        _compute_gauges, range(n), inputs, processes=n_workers, desc="Computing tangent spaces..."
+        _compute_gauges, range(n), inputs, processes=processes,
+        desc="\n---- Computing tangent spaces..."
     )
 
     gauges, Sigma = zip(*out)
@@ -565,21 +488,27 @@
 def _compute_gauges(inputs, i):
+    """Helper function to compute_gauges()"""
     X_chunks, A_chunks, dim_man, n_geodesic_nb = inputs
-
     gauges, Sigma = tangent_frames(X_chunks[i], A_chunks[i], dim_man, n_geodesic_nb)
 
     return gauges, Sigma
 
 
-def compute_connections(data, gauges, n_workers=1):
-    r"""Find smallest rotations R between gauges pairs. It is assumed that the first
+def compute_connections(data, gauges, processes=1):
+    """Find smallest rotations R between pairs of gauges. It is assumed that the first
     row of edge_index is what we want to align to, i.e.,
     gauges(i) = gauges(j)@R[i,j].T
 
+    R[i,j] is optimal rotation that minimises ||X - RY||_F computed by SVD:
+    X, Y = gauges[i].T, gauges[j].T
+    U, _, Vt = scipy.linalg.svd(X.T@Y)
+    R[i,j] = U@Vt
+
     Args:
         data: Pytorch geometric data object
         gauges (n,d,d matrix): Orthogonal unit vectors for each node
+        processes: number of CPUs to use
 
     Returns:
         (n*dim,n*dim) matrix of rotation matrices
     """
@@ -597,13 +526,15 @@
     inputs = [gauges, A, dim_man]
 
     out = utils.parallel_proc(
-        _compute_connections, range(n), inputs, processes=n_workers, desc="Computing connections..."
+        _compute_connections, range(n), inputs, processes=processes,
+        desc="\n---- Computing connections..."
     )
 
     return utils.to_block_diag(out)
 
 
 def _compute_connections(inputs, i):
+    """Helper function to compute_connections()"""
     gauges_chunks, A_chunks, dim_man = inputs
 
     R = connections(gauges_chunks[i], A_chunks[i], dim_man)
@@ -614,66 +545,6 @@
     return torch.sparse_coo_tensor(edge_index, R.flatten(), dtype=torch.float32).coalesce()
 
 
-def scalar_diffusion(x, t, method="matrix_exp", par=None):
-    """Scalar diffusion."""
-    if len(x.shape) == 1:
-        x = x.unsqueeze(1)
-
-    if method == "matrix_exp":
-        if par.is_sparse:
-            par = par.to_dense()
-        return torch.matrix_exp(-t * par.to_dense()).mm(x)
-
-    if method == "spectral":
-        assert (
-            isinstance(par, (list, tuple)) and len(par) == 2
-        ), "For spectral method, par must be a tuple of \
-            eigenvalues, eigenvectors!"
- evals, evecs = par - - # Transform to spectral - x_spec = torch.mm(evecs.T, x) - - # Diffuse - diffusion_coefs = torch.exp(-evals.unsqueeze(-1) * t.unsqueeze(0)) - x_diffuse_spec = diffusion_coefs * x_spec - - # Transform back to per-vertex - return evecs.mm(x_diffuse_spec) - - raise NotImplementedError - - -def vector_diffusion(x, t, method="spectral", Lc=None): - """Vector diffusion.""" - n, d = x.shape[0], x.shape[1] - - if method == "spectral": - assert len(Lc) == 2, "Lc must be a tuple of eigenvalues, eigenvectors!" - nd = Lc[0].shape[0] - else: - nd = Lc.shape[0] - - assert ( - n * d % nd - ) == 0, "Data dimension must be an integer multiple of the dimensions \ - of the connection Laplacian!" - - # vector diffusion with connection Laplacian - out = x.view(nd, -1) - out = scalar_diffusion(out, t, method, Lc) - out = out.view(x.shape) - - # if normalise: - # assert par['L'] is not None, 'Need Laplacian for normalised diffusion!' - # x_abs = x.norm(dim=-1, p=2, keepdim=True) - # out_abs = scalar_diffusion(x_abs, t, method, par['L']) - # ind = scalar_diffusion(torch.ones(x.shape[0],1), t, method, par['L']) - # out = out*out_abs/(ind*out.norm(dim=-1, p=2, keepdim=True)) - - return out - - def compute_eigendecomposition(A, k=None, eps=1e-8): """Eigendecomposition of a square matrix A. @@ -688,11 +559,13 @@ def compute_eigendecomposition(A, k=None, eps=1e-8): """ if A is None: return None - - if k is not None and k >= A.shape[0]: - k = None - - # Compute the eigenbasis + + if k is None: + A = A.to_dense() + else: + indices, values, size = A.indices(), A.values(), A.size() + A = sp.coo_array((values, (indices[0], indices[1])), shape=size) + failcount = 0 while True: try: @@ -712,6 +585,6 @@ def compute_eigendecomposition(A, k=None, eps=1e-8): raise ValueError("failed to compute eigendecomp") from e failcount += 1 print("--- decomp failed; adding eps ===> count: " + str(failcount)) - A += torch.eye(A.shape[0]) * (eps * 10 ** (failcount - 1)) + A += sp.eye(A.shape[0]) * (eps * 10 ** (failcount - 1)) return evals, evecs diff --git a/MARBLE/layers.py b/MARBLE/layers.py index 361d73c5..2111081c 100644 --- a/MARBLE/layers.py +++ b/MARBLE/layers.py @@ -3,7 +3,7 @@ from torch import nn from torch_geometric.nn.conv import MessagePassing -from MARBLE import geometry as g +from MARBLE import smoothing as s class Diffusion(nn.Module): @@ -20,8 +20,7 @@ def forward(self, x, L, Lc=None, method="spectral"): if method == "spectral": assert ( len(L) == 2 - ), "L must be a matrix or a pair of eigenvalues \ - and eigenvectors" + ), "L must be a matrix or a pair of eigenvalues and eigenvectors" # making sure diffusion times are positive with torch.no_grad(): @@ -30,9 +29,9 @@ def forward(self, x, L, Lc=None, method="spectral"): t = self.diffusion_time if Lc is not None: - out = g.vector_diffusion(x, t, method, Lc) + out = s.vector_diffusion(x, t, Lc, L=L, method=method, normalise=True) else: - out = [g.scalar_diffusion(x_, t, method, L) for x_ in x.T] + out = [s.scalar_diffusion(x_, t, method, L) for x_ in x.T] out = torch.cat(out, axis=1) return out diff --git a/MARBLE/main.py b/MARBLE/main.py index 98bc149e..115a153a 100644 --- a/MARBLE/main.py +++ b/MARBLE/main.py @@ -217,16 +217,19 @@ def forward(self, data, n_id, adjs=None): n, d = x.shape[0], data.gauges.shape[2] mask = data.mask - # local gauges - if self.params["inner_product_features"]: - x = geometry.map_to_local_gauges(x, data.gauges) - # diffusion if self.params["diffusion"]: - L = data.L.copy() if hasattr(data, "L") else None - Lc = 
data.Lc.copy() if hasattr(data, "Lc") else None - x = self.diffusion(x, L, Lc=Lc, method="spectral") - + if hasattr(data, "Lc"): + x = geometry.global_to_local_frame(x, data.gauges) + x = self.diffusion(x, data.L, Lc=data.Lc, method="spectral") + x = geometry.global_to_local_frame(x, data.gauges, reverse=True) + else: + x = self.diffusion(x, data.L, method="spectral") + + # local gauges + if self.params["inner_product_features"]: + x = geometry.global_to_local_frame(x, data.gauges) + # restrict to current batch x = x[n_id] mask = mask[n_id] @@ -265,7 +268,7 @@ def forward(self, data, n_id, adjs=None): if self.params["include_positions"]: out = torch.hstack([data.pos[n_id[: last_size[1]]], out]) - + emb = self.enc(out) if self.params["emb_norm"]: # spherical output diff --git a/MARBLE/plotting.py b/MARBLE/plotting.py index e637f31f..270c9af3 100644 --- a/MARBLE/plotting.py +++ b/MARBLE/plotting.py @@ -1,4 +1,5 @@ """Plotting module.""" +import torch import matplotlib import matplotlib.pyplot as plt import networkx as nx @@ -7,7 +8,6 @@ import torch from matplotlib import gridspec from matplotlib.colors import LinearSegmentedColormap -from matplotlib.patches import FancyArrowPatch from mpl_toolkits.mplot3d import proj3d from scipy.spatial import Voronoi from scipy.spatial import voronoi_plot_2d diff --git a/MARBLE/postprocessing.py b/MARBLE/postprocessing.py index 16f4384e..9a465114 100644 --- a/MARBLE/postprocessing.py +++ b/MARBLE/postprocessing.py @@ -79,58 +79,4 @@ def embed_in_2D(data, embed_typ="umap", manifold=None, seed=0): else: data.emb_2D, data.manifold = g.embed(emb, embed_typ=embed_typ, manifold=manifold, seed=seed) - return data - - -# def compare_attractors(data, source_target): -# """Compare attractors.""" -# assert all( -# hasattr(data, attr) for attr in ["emb", "gamma", "clusters", "cdist"] -# ), "It looks like postprocessing has not been run..." - -# s, t = source_target -# slices = data._slice_dict["x"] # pylint: disable=protected-access -# n_slices = len(slices) - 1 -# s_s = range(slices[s], slices[s + 1]) -# s_t = range(slices[t], slices[t + 1]) - -# assert s < n_slices - 2 and t < n_slices - 1, "Source and target must be < number of slices!" -# assert s != t, "Source and target must be different!" - -# _, ax = plt.subplots(1, 3, figsize=(10, 5)) - -# # plot embedding of all points in gray -# plotting.embedding(data.emb, ax=ax[0], alpha=0.05) - -# # get gamma matrix for the given source-target pair -# gammadist = data.gamma[s, t, ...] 
-#     np.fill_diagonal(gammadist, 0.0)
-
-#     # color code source features
-#     c = gammadist.sum(1)
-#     cluster_ids = set(data.clusters["labels"][s_s])
-#     labels = list(s_s)
-#     for cid in cluster_ids:
-#         idx = np.where(cid == data.clusters["labels"][s_s])[0]
-#         for i in idx:
-#             labels[i] = c[cid]
-
-#     # plot source features in red
-#     plotting.embedding(data.emb_2d[s_s], labels=labels, ax=ax[0], alpha=1.0)
-#     prop_dict = {"style": ">", "lw": 2}
-#     plotting.trajectories(data.pos[s_s], data.x[s_s], ax=ax[1], node_feature=labels, **prop_dict)
-#     ax[1].set_title("Before")
-
-#     # color code target features
-#     c = gammadist.sum(0)
-#     cluster_ids = set(data.clusters["labels"][s_t])
-#     labels = list(s_t)
-#     for cid in cluster_ids:
-#         idx = np.where(cid == data.clusters["labels"][s_t])[0]
-#         for i in idx:
-#             labels[i] = -c[cid]  # negative for blue color
-
-#     # plot target features in blue
-#     plotting.embedding(data.emb_2d[s_t], labels=labels, ax=ax[0], alpha=1.0)
-#     plotting.trajectories(data.pos[s_t], data.x[s_t], ax=ax[2], node_feature=labels, **prop_dict)
-#     ax[2].set_title("After")
+    return data
\ No newline at end of file
diff --git a/MARBLE/preprocessing.py b/MARBLE/preprocessing.py
index 836106c4..d6001ce7 100644
--- a/MARBLE/preprocessing.py
+++ b/MARBLE/preprocessing.py
@@ -1,7 +1,6 @@
 """Preprocessing module."""
 import torch
-from torch_geometric.data import Batch
-from torch_geometric.data import Data
+from torch_geometric.data import Batch, Data
 from torch_geometric.transforms import RandomNodeSplit
 
 from MARBLE import geometry as g
@@ -9,15 +8,16 @@
 
 
 def construct_dataset(
-    pos,
-    features,
-    labels=None,
+    anchor,
+    vector,
+    label=None,
     mask=None,
     graph_type="cknn",
     k=20,
     delta=1.0,
+    n_eigenvalues=None,
     frac_geodesic_nb=1.5,
-    stop_crit=0.0,
+    spacing=0.0,
     number_of_resamples=1,
     var_explained=0.9,
     local_gauges=False,
@@ -28,9 +28,11 @@
-        pos: matrix with position of points
-        features: matrix with feature values for each point
-        labels: any additional data labels used for plotting only
+        anchor: matrix with positions of points
+        vector: matrix with feature values for each point
+        label: any additional data labels used for plotting only
+        mask: boolean array marking points that will be forced to be close (default is None)
         graph_type: type of nearest-neighbours graph: cknn (default), knn or radius
         k: number of nearest-neighbours to construct the graph
-        delta: argument for cknn graph construction to decide the radius for each points.
+        delta: argument for cknn graph construction to decide the radius for each point.
+        n_eigenvalues: number of eigenvalue/eigenvector pairs to compute (None means all, but this can be slow)
-        frac_geodesic_nb: number of geodesic neighbours to fit the gauges to to map to tangent space k*frac_geodesic_nb
-        stop_crit: stopping criterion for furthest point sampling
+        frac_geodesic_nb: fraction of geodesic neighbours, relative to k, used to fit the gauges (i.e., k*frac_geodesic_nb neighbours)
+        spacing: minimum spacing (as a fraction of the manifold diameter) used as the stopping criterion for furthest point sampling
@@ -40,45 +42,46 @@
         embedding dimension is > 2 or dim embedding is not dim of manifold)
     """
-    pos = [torch.tensor(p).float() for p in utils.to_list(pos)]
-    features = [torch.tensor(x).float() for x in utils.to_list(features)]
-    num_node_features = features[0].shape[1]
-
-    if labels is None:
-        labels = [torch.arange(len(p)) for p in utils.to_list(pos)]
+    anchor = [torch.tensor(p).float() for p in utils.to_list(anchor)]
+    vector = [torch.tensor(x).float() for x in utils.to_list(vector)]
+    num_node_features = vector[0].shape[1]
+
+    if label is None:
+        label = [torch.arange(len(p)) for p in utils.to_list(anchor)]
     else:
-        labels = [torch.tensor(label).float() for label in utils.to_list(labels)]
+        label = [torch.tensor(l).float() for l in utils.to_list(label)]
 
     if mask is None:
-        mask = [torch.zeros(len(p), dtype=torch.bool) for p in utils.to_list(pos)]
+        mask = [torch.zeros(len(p), dtype=torch.bool) for p in utils.to_list(anchor)]
     else:
         mask = [torch.tensor(m) for m in utils.to_list(mask)]
 
-    if stop_crit == 0.0:
+    if spacing == 0.0:
         number_of_resamples = 1
 
     data_list = []
-    for i, (p, f, l, m) in enumerate(zip(pos, features, labels, mask)):
+    for i, (a, v, l, m) in enumerate(zip(anchor, vector, label, mask)):
         for _ in range(number_of_resamples):
             # even sampling of points
-            start_idx = torch.randint(low=0, high=len(p), size=(1,))
-            sample_ind, _ = g.furthest_point_sampling(p, stop_crit=stop_crit, start_idx=start_idx)
-            sample_ind, _ = torch.sort(sample_ind)  # this will make postprocessing easier
-            p_, f_, l_, m_ = p[sample_ind], f[sample_ind], l[sample_ind], m[sample_ind]
+            start_idx = torch.randint(low=0, high=len(a), size=(1,))
+            sample_ind, _ = g.furthest_point_sampling(a, spacing=spacing, start_idx=start_idx)
+            sample_ind, _ = torch.sort(sample_ind)  # this will make postprocessing easier
+            a_, v_, l_, m_ = a[sample_ind], v[sample_ind], l[sample_ind], m[sample_ind]
 
             # fit graph to point cloud
-            edge_index, edge_weight = g.fit_graph(p_, graph_type=graph_type, par=k, delta=delta)
-            n = len(p_)
+            edge_index, edge_weight = g.fit_graph(a_, graph_type=graph_type, par=k, delta=delta)
+
+            # define data object
             data_ = Data(
-                pos=p_,
-                x=f_,
-                l=l_,
+                pos=a_,
+                x=v_,
+                label=l_,
                 mask=m_,
                 edge_index=edge_index,
                 edge_weight=edge_weight,
-                num_nodes=n,
+                num_nodes=len(a_),
                 num_node_features=num_node_features,
-                y=torch.ones(n, dtype=int) * i,
+                y=torch.ones(len(a_), dtype=int) * i,
                 sample_ind=sample_ind,
             )
@@ -96,14 +99,14 @@
     return _compute_geometric_objects(
         batch,
         local_gauges=local_gauges,
-        frac_geodesic_nb=frac_geodesic_nb,
+        n_geodesic_nb=k * frac_geodesic_nb,
         var_explained=var_explained,
     )
 
 
 def _compute_geometric_objects(
     data,
-    frac_geodesic_nb=2.0,
+    n_geodesic_nb=10,
     var_explained=0.9,
     local_gauges=False,
 ):
@@ -113,8 +116,7 @@
 
     Args:
         data: pytorch geometric data object
-        frac_geodesic_nb: fraction of geodesic neighbours relative to neighbours
-            to fit the tangent spaces to
+        n_geodesic_nb: number of geodesic neighbours to fit the tangent spaces to
         var_explained: fraction of variance explained by the local gauges
         local_gauges: whether to use local or global gauges
@@ -146,7 +148,7 @@
 
     if local_gauges:
        try:
-            gauges, Sigma = g.compute_gauges(data, n_geodesic_nb=frac_geodesic_nb)
+            gauges, Sigma = g.compute_gauges(data, n_geodesic_nb=n_geodesic_nb)
         except Exception as exc:
             raise Exception(
                 "\nCould not compute gauges (possibly data is too sparse or the \
@@ -159,7 +161,7 @@
 
     if local_gauges:
         data.dim_man = g.manifold_dimension(Sigma, frac_explained=var_explained)
-        print(f"\n---- Manifold dimension: {data.dim_man}")
+        print(f"---- Manifold dimension: {data.dim_man}")
 
         gauges = gauges[:, :, : data.dim_man]
         R = g.compute_connections(data, gauges)
@@ -176,7 +178,7 @@
         kernels = g.gradient_op(data.pos, data.edge_index, gauges)
         Lc = None
 
-    print("\n---- Computing eigendecomposition ... ", end="")
+    # print("\n---- Computing eigendecomposition ... ", end="")
     L = g.compute_eigendecomposition(L)
     Lc = g.compute_eigendecomposition(Lc)
diff --git a/MARBLE/smoothing.py b/MARBLE/smoothing.py
new file mode 100644
index 00000000..a424cec6
--- /dev/null
+++ b/MARBLE/smoothing.py
@@ -0,0 +1,66 @@
+#!/usr/bin/env python3
+# -*- coding: utf-8 -*-
+
+import torch
+
+def scalar_diffusion(x, t, method="matrix_exp", par=None):
+    """Scalar diffusion."""
+    if len(x.shape) == 1:
+        x = x.unsqueeze(1)
+
+    if method == "matrix_exp":
+        if par.is_sparse:
+            par = par.to_dense()
+        return torch.matrix_exp(-t * par).mm(x)
+
+    if method == "spectral":
+        assert (
+            isinstance(par, (list, tuple)) and len(par) == 2
+        ), "For spectral method, par must be a tuple of \
+            eigenvalues, eigenvectors!"
+        evals, evecs = par
+
+        # Transform to spectral
+        x_spec = torch.mm(evecs.T, x)
+
+        # Diffuse
+        diffusion_coefs = torch.exp(-evals.unsqueeze(-1) * t.unsqueeze(0))
+        x_diffuse_spec = diffusion_coefs * x_spec
+
+        # Transform back to per-vertex
+        return evecs.mm(x_diffuse_spec)
+
+    raise NotImplementedError
+
+
+def vector_diffusion(x, t, Lc, L=None, method="spectral", normalise=True):
+    """Vector diffusion."""
+    n, d = x.shape[0], x.shape[1]
+
+    if method == "spectral":
+        assert len(Lc) == 2, "Lc must be a tuple of eigenvalues, eigenvectors!"
+        nd = Lc[1].shape[0]
+    else:
+        nd = Lc.shape[0]
+
+    assert (
+        n * d % nd
+    ) == 0, "Data dimension must be an integer multiple of the dimensions \
+        of the connection Laplacian!"
+
+    # vector diffusion with connection Laplacian
+    out = x.view(nd, -1)
+    out = scalar_diffusion(out, t, method, Lc)
+    out = out.view(x.shape)
+
+    if normalise:
+        assert L is not None, 'Need Laplacian for normalised diffusion!'
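+        # Normalise: diffuse the input vector norms as a scalar field under L and
+        # rescale the connection-diffused output to those norms (divided by `ind`,
+        # the diffusion of the constant field), so magnitudes follow scalar heat flow.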
+        x_abs = x.norm(dim=-1, p=2, keepdim=True)
+        out_abs = scalar_diffusion(x_abs, t, method, L)
+        ind = scalar_diffusion(torch.ones(x.shape[0], 1), t, method, L)
+        out = out * out_abs / (ind * out.norm(dim=-1, p=2, keepdim=True))
+
+    return out
\ No newline at end of file
diff --git a/MARBLE/utils.py b/MARBLE/utils.py
index ee50506f..d13f1aed 100644
--- a/MARBLE/utils.py
+++ b/MARBLE/utils.py
@@ -59,12 +59,18 @@ def move_to_gpu(model, data, adjs=None):
         data.mask = data.mask.to(device)
 
     if hasattr(data, "L"):
-        data.L = [_l.to(device) for _l in data.L]
+        if len(data.L) == 2:
+            data.L = [_l.to(device) for _l in data.L]
+        else:
+            data.L = data.L.to(device)
     else:
         data.L = None
 
     if hasattr(data, "Lc"):
-        data.Lc = [_l.to(device) for _l in data.Lc]
+        if len(data.Lc) == 2:
+            data.Lc = [_l.to(device) for _l in data.Lc]
+        else:
+            data.Lc = data.Lc.to(device)
     else:
         data.Lc = None
 
@@ -261,16 +267,10 @@ def restrict_to_batch(sp_tensor, idx):
     raise NotImplementedError
 
 
-def standardise(X, zero_mean=True, norm="std"):
-    """Standarsise data row-wise"""
-    if zero_mean:
-        X -= X.mean(axis=0, keepdims=True)
-    elif norm == "std":
-        X /= X.std(axis=0, keepdims=True)
-    elif norm == "max":
-        X /= abs(X).max(axis=0, keepdims=True)
-    else:
-        raise NotImplementedError
+def standardize(X):
+    """Standardize data column-wise to zero mean and unit variance."""
+    mean = X.mean(axis=0, keepdims=True)
+    std = X.std(axis=0, keepdims=True)
 
-    return X
+    return (X - mean) / std
diff --git a/examples/toy_examples/ex_vector_field_curved_surface.py b/examples/toy_examples/ex_vector_field_curved_surface.py
index 08d7bfe0..c45c1b5e 100644
--- a/examples/toy_examples/ex_vector_field_curved_surface.py
+++ b/examples/toy_examples/ex_vector_field_curved_surface.py
@@ -52,7 +52,7 @@ def main():
 
     # construct PyG data object
     data = preprocessing.construct_dataset(
-        x, y, graph_type="cknn", k=10, n_geodesic_nb=20, local_gauges=True  # use local gauges
+        x, y, graph_type="cknn", k=10, local_gauges=True  # use local gauges
     )
 
     # train model
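
Example: minimal usage of the renamed `construct_dataset` API after this patch (a sketch only; the random point cloud, the rotational vector field, and the parameter values below are illustrative and not part of the patch):

```python
import numpy as np

from MARBLE import preprocessing

# Illustrative inputs: positions of points on the manifold ("anchor", formerly
# `pos`) and a vector signal at those points ("vector", formerly `features`).
n = 500
anchor = np.random.uniform(-1, 1, size=(n, 2))
vector = np.stack([-anchor[:, 1], anchor[:, 0]], axis=1)  # rotational field

data = preprocessing.construct_dataset(
    anchor,
    vector,
    graph_type="cknn",
    k=10,
    spacing=0.05,  # formerly `stop_crit`: minimum spacing, as a fraction of the
                   # manifold diameter, used in furthest point sampling
    local_gauges=False,
)
```

Since a nonzero `spacing` triggers furthest-point subsampling, `number_of_resamples` can be raised to average over several subsamplings; with the default `spacing=0.0` all points are kept and only one resample is used.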