From 5cf781147278e8b0879b7e9038b02f27167ab124 Mon Sep 17 00:00:00 2001 From: arnaudon Date: Tue, 28 Nov 2023 15:10:57 +0100 Subject: [PATCH 1/6] format --- MARBLE/geometry.py | 41 ++++++++++-------- MARBLE/layers.py | 2 +- MARBLE/main.py | 81 ++++++++++++++++++----------------- MARBLE/plotting.py | 92 +++++++++++++++++++++++----------------- MARBLE/postprocessing.py | 7 ++- MARBLE/preprocessing.py | 12 +++--- 6 files changed, 129 insertions(+), 106 deletions(-) diff --git a/MARBLE/geometry.py b/MARBLE/geometry.py index bcc6bb92..157cf30c 100644 --- a/MARBLE/geometry.py +++ b/MARBLE/geometry.py @@ -5,12 +5,16 @@ import torch import torch_geometric.utils as PyGu import umap -from sklearn.cluster import KMeans, MeanShift +from sklearn.cluster import KMeans +from sklearn.cluster import MeanShift from sklearn.decomposition import PCA -from sklearn.manifold import TSNE, Isomap, MDS +from sklearn.manifold import MDS +from sklearn.manifold import TSNE +from sklearn.manifold import Isomap from sklearn.metrics import pairwise_distances from sklearn.preprocessing import StandardScaler -from torch_geometric.nn import knn_graph, radius_graph +from torch_geometric.nn import knn_graph +from torch_geometric.nn import radius_graph from torch_scatter import scatter_add from ptu_dijkstra import connections, tangent_frames # isort:skip @@ -18,6 +22,7 @@ from MARBLE.lib.cknn import cknneighbors_graph # isort:skip from MARBLE import utils # isort:skip + def furthest_point_sampling(x, N=None, stop_crit=0.0, start_idx=0): """A greedy O(N^2) algorithm to do furthest points sampling @@ -42,7 +47,7 @@ def furthest_point_sampling(x, N=None, stop_crit=0.0, start_idx=0): perm[0] = start_idx lambdas = torch.zeros(n) ds = D[start_idx, :].flatten() - for i in range(1,n): + for i in range(1, n): idx = torch.argmax(ds) perm[i] = idx lambdas[i] = ds[idx] @@ -54,7 +59,7 @@ def furthest_point_sampling(x, N=None, stop_crit=0.0, start_idx=0): lambdas = lambdas[:i] break - assert len(perm)==len(np.unique(perm)), 'Returned duplicated points' + assert len(perm) == len(np.unique(perm)), "Returned duplicated points" return perm, lambdas @@ -131,10 +136,10 @@ def embed(x, embed_typ="umap", dim_emb=2, manifold=None, seed=0, **kwargs): manifold = PCA(n_components=dim_emb).fit(x) emb = manifold.transform(x) - + elif embed_typ == "Isomap": radius = pairwise_distances(x) - radius = 0.1*(radius.max()-radius.min()) + radius = 0.1 * (radius.max() - radius.min()) if manifold is None: manifold = Isomap(n_components=dim_emb, n_neighbors=None, radius=radius).fit(x) @@ -296,7 +301,7 @@ def gradient_op(pos, edge_index, gauges): _F -= sp.diags(np.array(_F.sum(1)).flatten()) _F = _F.tocoo() K.append(torch.sparse_coo_tensor(np.vstack([_F.row, _F.col]), _F.data.data)) - + return K @@ -450,10 +455,10 @@ def is_connected(edge_index): def compute_laplacian(data, normalization="rw"): """Compute Laplacian.""" edge_index, edge_attr = PyGu.get_laplacian( - data.edge_index, - edge_weight=data.edge_weight, - normalization=normalization, - num_nodes=data.num_nodes + data.edge_index, + edge_weight=data.edge_weight, + normalization=normalization, + num_nodes=data.num_nodes, ) return PyGu.to_dense_adj(edge_index, edge_attr=edge_attr).squeeze() @@ -625,17 +630,17 @@ def scalar_diffusion(x, t, method="matrix_exp", par=None): ), "For spectral method, par must be a tuple of \ eigenvalues, eigenvectors!" evals, evecs = par - + # Transform to spectral x_spec = torch.mm(evecs.T, x) # Diffuse diffusion_coefs = torch.exp(-evals.unsqueeze(-1) * t.unsqueeze(0)) x_diffuse_spec = diffusion_coefs * x_spec - + # Transform back to per-vertex return evecs.mm(x_diffuse_spec) - + raise NotImplementedError @@ -683,7 +688,7 @@ def compute_eigendecomposition(A, k=None, eps=1e-8): """ if A is None: return None - + if k is not None and k >= A.shape[0]: k = None @@ -696,7 +701,7 @@ def compute_eigendecomposition(A, k=None, eps=1e-8): else: evals, evecs = sp.linalg.eigsh(A, k=k, which="SM") evals, evecs = torch.tensor(evals), torch.tensor(evecs) - + evals = torch.clamp(evals, min=0.0) evecs *= np.sqrt(len(evecs)) @@ -709,4 +714,4 @@ def compute_eigendecomposition(A, k=None, eps=1e-8): print("--- decomp failed; adding eps ===> count: " + str(failcount)) A += torch.eye(A.shape[0]) * (eps * 10 ** (failcount - 1)) - return evals, evecs \ No newline at end of file + return evals, evecs diff --git a/MARBLE/layers.py b/MARBLE/layers.py index 2ce1e8aa..361d73c5 100644 --- a/MARBLE/layers.py +++ b/MARBLE/layers.py @@ -1,11 +1,11 @@ """Layer module.""" import torch from torch import nn -from torch.nn.functional import normalize, relu from torch_geometric.nn.conv import MessagePassing from MARBLE import geometry as g + class Diffusion(nn.Module): """Diffusion with learned t.""" diff --git a/MARBLE/main.py b/MARBLE/main.py index ff882fc0..432229ca 100644 --- a/MARBLE/main.py +++ b/MARBLE/main.py @@ -1,6 +1,7 @@ """Main network""" import glob import os +import warnings from datetime import datetime from pathlib import Path import yaml @@ -14,7 +15,6 @@ from MARBLE import dataloader, geometry, layers, utils -import warnings class net(nn.Module): """MARBLE neural network. @@ -56,8 +56,8 @@ def __init__(self, data, loadpath=None, params=None, verbose=True): """ super().__init__() - device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') - + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + if loadpath is not None: if Path(loadpath).is_dir(): loadpath = max(glob.glob(f"{loadpath}/best_model*")) @@ -65,7 +65,7 @@ def __init__(self, data, loadpath=None, params=None, verbose=True): else: if params is not None: self.params = params - else: + else: self.params = {} self._epoch = 0 # to resume optimisation @@ -192,20 +192,17 @@ def setup_layers(self): # encoder if not isinstance(self.params["hidden_channels"], list): self.params["hidden_channels"] = [self.params["hidden_channels"]] - + channel_list = ( - [cum_channels] - + self.params["hidden_channels"] - + [self.params["out_channels"]] + [cum_channels] + self.params["hidden_channels"] + [self.params["out_channels"]] ) self.enc = MLP( channel_list=channel_list, dropout=self.params["dropout"], bias=self.params["bias"], - ) - - + ) + def forward(self, data, n_id, adjs=None): """Forward pass. Messages are passed to a set target nodes (current batch) from source @@ -226,7 +223,7 @@ def forward(self, data, n_id, adjs=None): L = data.L.copy() if hasattr(data, "L") else None Lc = data.Lc.copy() if hasattr(data, "Lc") else None x = self.diffusion(x, L, Lc=Lc, method="spectral") - + # restrict to current batch x = x[n_id] mask = mask[n_id] @@ -234,7 +231,7 @@ def forward(self, data, n_id, adjs=None): n_id = utils.expand_index(n_id, d) else: d = 1 - + if self.params["vec_norm"]: x = F.normalize(x, dim=-1, p=2) @@ -245,12 +242,12 @@ def forward(self, data, n_id, adjs=None): out = [] for i, (_, _, size) in enumerate(adjs): kernels = [K[n_id[: size[1] * d], :][:, n_id[: size[0] * d]] for K in data.kernels] - + x = self.grad[i](x, kernels) - + if self.params["vec_norm"]: x = F.normalize(x, dim=-1, p=2) - + out.append(x) # take target nodes @@ -265,16 +262,16 @@ def forward(self, data, n_id, adjs=None): if self.params["include_positions"]: out = torch.hstack( [data.pos[n_id[: size[1]]], out] # pylint: disable=undefined-loop-variable - ) - + ) + emb = self.enc(out) - - if self.params['emb_norm']: # spherical output - emb = F.normalize(emb) - + + if self.params["emb_norm"]: # spherical output + emb = F.normalize(emb) + return emb, mask[: size[1]] - - def evaluate(self, data): + + def evaluate(self, data): warnings.warn("MARBLE.evaluate() is deprecated. Use MARBLE.transform() instead.") self.transform(data) @@ -320,7 +317,7 @@ def batch_loss(self, data, loader, train=False, verbose=False, optimizer=None): for batch in tqdm(loader, disable=not verbose): _, n_id, adjs = batch adjs = [adj.to(data.x.device) for adj in utils.to_list(adjs)] - + emb, mask = self.forward(data, n_id, adjs) loss = self.loss(emb, mask) cum_loss += float(loss) @@ -333,7 +330,7 @@ def batch_loss(self, data, loader, train=False, verbose=False, optimizer=None): self.eval() return cum_loss / len(loader), optimizer - + def run_training(self, data, outdir=None, verbose=False): warnings.warn("MARBLE.run_training() is deprecated. Use MARBLE.fit() instead.") @@ -368,7 +365,7 @@ def fit(self, data, outdir=None, verbose=False): scheduler = opt.lr_scheduler.ReduceLROnPlateau(optimizer) best_loss = -1 - self.losses = {'train_loss': [], 'val_loss': [], 'test_loss': []} + self.losses = {"train_loss": [], "val_loss": [], "test_loss": []} for epoch in range( self.params.get("epoch", 0), self.params.get("epoch", 0) + self.params["epochs"] ): @@ -386,18 +383,20 @@ def fit(self, data, outdir=None, verbose=False): ) if best_loss == -1 or (val_loss < best_loss): - outdir = self.save_model(optimizer, self.losses, outdir=outdir, best=True, timestamp=time) + outdir = self.save_model( + optimizer, self.losses, outdir=outdir, best=True, timestamp=time + ) best_loss = val_loss print(" *", end="") - - self.losses['train_loss'].append(train_loss) - self.losses['val_loss'].append(val_loss) + + self.losses["train_loss"].append(train_loss) + self.losses["val_loss"].append(val_loss) test_loss, _ = self.batch_loss(data, test_loader) print(f"\nFinal test loss: {test_loss:.4f}") - self.losses['test_loss'].append(test_loss) - + self.losses["test_loss"].append(test_loss) + self.save_model(optimizer, self.losses, outdir=outdir, best=False, timestamp=time) self.load_model(os.path.join(outdir, f"best_model_{time}.pth")) @@ -407,11 +406,13 @@ def load_model(self, loadpath): Args: loadpath: directory with models to load best model, or specific model path """ - checkpoint = torch.load(loadpath, map_location=torch.device('cuda' if torch.cuda.is_available() else 'cpu')) + checkpoint = torch.load( + loadpath, map_location=torch.device("cuda" if torch.cuda.is_available() else "cpu") + ) self._epoch = checkpoint["epoch"] self.load_state_dict(checkpoint["model_state_dict"]) self.optimizer_state_dict = checkpoint["optimizer_state_dict"] - self.losses = checkpoint['losses'] + self.losses = checkpoint["losses"] def save_model(self, optimizer, losses, outdir=None, best=False, timestamp=""): """Save model.""" @@ -448,16 +449,16 @@ def save_model(self, optimizer, losses, outdir=None, best=False, timestamp=""): class loss_fun(nn.Module): """Loss function.""" - + def forward(self, out, mask=None): """forward.""" z, z_pos, z_neg = out.split(out.size(0) // 3, dim=0) pos_loss = F.logsigmoid((z * z_pos).sum(-1)).mean() neg_loss = F.logsigmoid(-(z * z_neg).sum(-1)).mean() - - coagulation_loss = 0. + + coagulation_loss = 0.0 if mask is not None: z_mask = out[mask] - coagulation_loss = (z_mask-z_mask.mean(dim=0)).norm(dim=1).sum() + coagulation_loss = (z_mask - z_mask.mean(dim=0)).norm(dim=1).sum() - return -pos_loss -neg_loss + torch.sigmoid(coagulation_loss) - 0.5 \ No newline at end of file + return -pos_loss - neg_loss + torch.sigmoid(coagulation_loss) - 0.5 diff --git a/MARBLE/plotting.py b/MARBLE/plotting.py index 37977c13..7ad784f2 100644 --- a/MARBLE/plotting.py +++ b/MARBLE/plotting.py @@ -8,10 +8,10 @@ import numpy as np import seaborn as sns from matplotlib import gridspec -from matplotlib.patches import FancyArrowPatch from matplotlib.colors import LinearSegmentedColormap -from mpl_toolkits.mplot3d.art3d import Line3DCollection +from matplotlib.patches import FancyArrowPatch from mpl_toolkits.mplot3d import proj3d +from mpl_toolkits.mplot3d.art3d import Line3DCollection from scipy.spatial import Voronoi from scipy.spatial import voronoi_plot_2d from torch_geometric.utils.convert import to_networkx @@ -129,9 +129,7 @@ def histograms(data, titles=None, col=2, figsize=(10, 10)): col: int for number of columns to plot figsize: tuple of figure dimensions """ - assert hasattr( - data, "clusters" - ), "No clusters found. First, run postprocessing.cluster(data)!" + assert hasattr(data, "clusters"), "No clusters found. First, run postprocessing.cluster(data)!" labels, s = data.clusters["labels"], data.clusters["slices"] n_slices = len(s) - 1 @@ -171,9 +169,9 @@ def embedding( clusters_visible=False, cmap="coolwarm", plot_trajectories=False, - style='o', + style="o", lw=1, - time_gradient=False + time_gradient=False, ): """Plot embeddings. @@ -199,54 +197,74 @@ def embedding( if labels is None: labels = np.ones(emb.shape[0]) - + if mask is None: mask = np.ones(len(emb), dtype=bool) labels = labels[mask] types = sorted(set(labels)) - + color, cbar = set_colors(types, cmap) - + if titles is not None: assert len(titles) == len(types) for i, typ in enumerate(types): title = titles[i] if titles is not None else str(typ) c_ = color[i] - emb_ = emb[mask*(labels == typ)] - + emb_ = emb[mask * (labels == typ)] + if isinstance(data, np.ndarray) or torch.is_tensor(data): - print('You need to pass a data object to plot trajectories!') + print("You need to pass a data object to plot trajectories!") plot_trajectories = False - + if plot_trajectories: - l_ = data.l[mask*(labels == typ)] + l_ = data.l[mask * (labels == typ)] if len(l_) == 0: continue - end = np.where(np.diff(l_)<0)[0]+1 + end = np.where(np.diff(l_) < 0)[0] + 1 start = np.hstack([0, end]) end = np.hstack([end, len(emb_)]) cmap = LinearSegmentedColormap.from_list("Custom", [(0, 0, 0), c_], N=max(l_)) - - for i, (s_,e_) in enumerate(zip(start, end)): - t = range(s_,e_) - cgrad = cmap(l_[t]/max(l_)) - if style=='-': + + for i, (s_, e_) in enumerate(zip(start, end)): + t = range(s_, e_) + cgrad = cmap(l_[t] / max(l_)) + if style == "-": if time_gradient: - trajectories(emb_[t], style='-', ax=ax, ms=s, node_feature=cgrad, alpha=alpha, lw=lw) + trajectories( + emb_[t], style="-", ax=ax, ms=s, node_feature=cgrad, alpha=alpha, lw=lw + ) else: - trajectories(emb_[t], style='-', ax=ax, ms=s, node_feature=[c_]*len(t), alpha=alpha, lw=lw) - elif style=='o': + trajectories( + emb_[t], + style="-", + ax=ax, + ms=s, + node_feature=[c_] * len(t), + alpha=alpha, + lw=lw, + ) + elif style == "o": if dim == 2: ax.scatter(emb_[t, 0], emb_[t, 1], c=cgrad, alpha=alpha, s=s, label=title) elif dim == 3: - ax.scatter(emb_[t, 0], emb_[t, 1], emb_[t, 2], c=cgrad, alpha=alpha, s=s, label=title) + ax.scatter( + emb_[t, 0], + emb_[t, 1], + emb_[t, 2], + c=cgrad, + alpha=alpha, + s=s, + label=title, + ) else: if dim == 2: ax.scatter(emb_[:, 0], emb_[:, 1], color=c_, alpha=alpha, s=s, label=title) elif dim == 3: - ax.scatter(emb_[:, 0], emb_[:, 1], emb_[:, 2], color=c_, alpha=alpha, s=s, label=title) + ax.scatter( + emb_[:, 0], emb_[:, 1], emb_[:, 2], color=c_, alpha=alpha, s=s, label=title + ) if dim == 2: if hasattr(data, "clusters") and clusters_visible: @@ -267,13 +285,13 @@ def embedding( def losses(model): """Model losses""" - plt.plot(model.losses['train_loss'], label='Training loss') - plt.plot(model.losses['val_loss'], label='Validation loss') - plt.xlabel('Epochs') - plt.ylabel('MSE loss') + plt.plot(model.losses["train_loss"], label="Training loss") + plt.plot(model.losses["val_loss"], label="Validation loss") + plt.xlabel("Epochs") + plt.ylabel("MSE loss") plt.legend() - - + + def voronoi(clusters, ax): """Voronoi tesselation of clusters""" vor = Voronoi(clusters["centroids"]) @@ -303,9 +321,7 @@ def neighbourhoods( plot_graph: if True, then plot the underlying graph. """ - assert hasattr( - data, "clusters" - ), "No clusters found. First, run postprocessing.cluster(data)!" + assert hasattr(data, "clusters"), "No clusters found. First, run postprocessing.cluster(data)!" vector = data.x.shape[1] > 1 clusters = data.clusters @@ -451,7 +467,7 @@ def graph( return ax -#def time_series(T, X, style="o", node_feature=None, figsize=(10, 5), lw=1, ms=5): +# def time_series(T, X, style="o", node_feature=None, figsize=(10, 5), lw=1, ms=5): # """Plot time series. # Args: @@ -533,7 +549,7 @@ def trajectories( _, ax = create_axis(dim) c = set_colors(node_feature)[0] - + if dim == 2: if "o" in style: ax.scatter(X[:, 0], X[:, 1], c=c, s=ms, alpha=alpha) @@ -723,4 +739,4 @@ def set_colors(color, cmap="coolwarm"): cbar = plt.cm.ScalarMappable(norm=norm, cmap=cmap) - return colors, cbar \ No newline at end of file + return colors, cbar diff --git a/MARBLE/postprocessing.py b/MARBLE/postprocessing.py index 8084635f..f8d9e449 100644 --- a/MARBLE/postprocessing.py +++ b/MARBLE/postprocessing.py @@ -5,17 +5,16 @@ def cluster(data, cluster_typ="kmeans", n_clusters=15, seed=0): - clusters = g.cluster(data.emb, cluster_typ, n_clusters, seed) clusters = g.relabel_by_proximity(clusters) - + clusters["slices"] = data._slice_dict["x"] # pylint: disable=protected-access if data.number_of_resamples > 1: clusters["slices"] = clusters["slices"][:: data.number_of_resamples] - + data.clusters = clusters - + return data diff --git a/MARBLE/preprocessing.py b/MARBLE/preprocessing.py index ec38d24a..4caf7662 100644 --- a/MARBLE/preprocessing.py +++ b/MARBLE/preprocessing.py @@ -3,6 +3,7 @@ from torch_geometric.data import Batch from torch_geometric.data import Data from torch_geometric.transforms import RandomNodeSplit + from MARBLE import geometry as g from MARBLE import utils @@ -42,12 +43,12 @@ def construct_dataset( pos = [torch.tensor(p).float() for p in utils.to_list(pos)] features = [torch.tensor(x).float() for x in utils.to_list(features)] num_node_features = features[0].shape[1] - + if labels is None: labels = [torch.arange(len(p)) for p in utils.to_list(pos)] else: labels = [torch.tensor(l).float() for l in utils.to_list(labels)] - + if mask is None: mask = [torch.zeros(len(p), dtype=torch.bool) for p in utils.to_list(pos)] else: @@ -62,7 +63,7 @@ def construct_dataset( # even sampling of points start_idx = torch.randint(low=0, high=len(p), size=(1,)) sample_ind, _ = g.furthest_point_sampling(p, stop_crit=stop_crit, start_idx=start_idx) - sample_ind, _ = torch.sort(sample_ind) #this will make postprocessing easier + sample_ind, _ = torch.sort(sample_ind) # this will make postprocessing easier p_, f_, l_, m_ = p[sample_ind], f[sample_ind], l[sample_ind], m[sample_ind] # fit graph to point cloud @@ -100,7 +101,8 @@ def construct_dataset( ) -def _compute_geometric_objects(data, +def _compute_geometric_objects( + data, frac_geodesic_nb=2.0, var_explained=0.9, local_gauges=False, @@ -158,7 +160,7 @@ def _compute_geometric_objects(data, data.dim_man = g.manifold_dimension(Sigma, frac_explained=var_explained) print(f"\n---- Manifold dimension: {data.dim_man}") - gauges = gauges[:, :, :data.dim_man] + gauges = gauges[:, :, : data.dim_man] R = g.compute_connections(data, gauges) print("\n---- Computing kernels ... ", end="") From 5756d605ccfded1dc26cff12aedb667bd68c6a9f Mon Sep 17 00:00:00 2001 From: arnaudon Date: Wed, 29 Nov 2023 17:29:48 +0100 Subject: [PATCH 2/6] more --- .pylintrc | 2 +- MARBLE/geometry.py | 2 +- MARBLE/main.py | 15 ++++++++------- MARBLE/plotting.py | 1 - MARBLE/postprocessing.py | 1 + MARBLE/preprocessing.py | 2 +- tests/test_vector_diffusion.py | 4 +++- 7 files changed, 15 insertions(+), 12 deletions(-) diff --git a/.pylintrc b/.pylintrc index 35e0a277..847b5136 100644 --- a/.pylintrc +++ b/.pylintrc @@ -19,7 +19,7 @@ max-locals=15 # Maximum number of return / yield for function / method body max-returns=6 # Maximum number of branch for function / method body -max-branches=20 +max-branches=30 # Maximum number of statements in function / method body max-statements=65 # Maximum number of parents for a class (see R0901). diff --git a/MARBLE/geometry.py b/MARBLE/geometry.py index 157cf30c..e3d74bb0 100644 --- a/MARBLE/geometry.py +++ b/MARBLE/geometry.py @@ -697,7 +697,7 @@ def compute_eigendecomposition(A, k=None, eps=1e-8): while True: try: if k is None: - evals, evecs = torch.linalg.eigh(A) + evals, evecs = torch.linalg.eigh(A) # pylint: disable=not-callable else: evals, evecs = sp.linalg.eigsh(A, k=k, which="SM") evals, evecs = torch.tensor(evals), torch.tensor(evecs) diff --git a/MARBLE/main.py b/MARBLE/main.py index 432229ca..c3d38705 100644 --- a/MARBLE/main.py +++ b/MARBLE/main.py @@ -250,8 +250,9 @@ def forward(self, data, n_id, adjs=None): out.append(x) + last_size = adjs[-1][2] # take target nodes - out = [o[: size[1]] for o in out] # pylint: disable=undefined-loop-variable + out = [o[: last_size[1]] for o in out] # inner products if self.params["inner_product_features"]: @@ -260,18 +261,17 @@ def forward(self, data, n_id, adjs=None): out = torch.cat(out, axis=1) if self.params["include_positions"]: - out = torch.hstack( - [data.pos[n_id[: size[1]]], out] # pylint: disable=undefined-loop-variable - ) + out = torch.hstack([data.pos[n_id[: last_size[1]]], out]) emb = self.enc(out) if self.params["emb_norm"]: # spherical output emb = F.normalize(emb) - return emb, mask[: size[1]] + return emb, mask[: last_size[1]] def evaluate(self, data): + """Evaluate.""" warnings.warn("MARBLE.evaluate() is deprecated. Use MARBLE.transform() instead.") self.transform(data) @@ -332,6 +332,7 @@ def batch_loss(self, data, loader, train=False, verbose=False, optimizer=None): return cum_loss / len(loader), optimizer def run_training(self, data, outdir=None, verbose=False): + """Run training.""" warnings.warn("MARBLE.run_training() is deprecated. Use MARBLE.fit() instead.") self.fit(data, outdir=outdir, verbose=verbose) @@ -453,8 +454,8 @@ class loss_fun(nn.Module): def forward(self, out, mask=None): """forward.""" z, z_pos, z_neg = out.split(out.size(0) // 3, dim=0) - pos_loss = F.logsigmoid((z * z_pos).sum(-1)).mean() - neg_loss = F.logsigmoid(-(z * z_neg).sum(-1)).mean() + pos_loss = F.logsigmoid((z * z_pos).sum(-1)).mean() # pylint: disable=not-callable + neg_loss = F.logsigmoid(-(z * z_neg).sum(-1)).mean() # pylint: disable=not-callable coagulation_loss = 0.0 if mask is not None: diff --git a/MARBLE/plotting.py b/MARBLE/plotting.py index 7ad784f2..099fe09b 100644 --- a/MARBLE/plotting.py +++ b/MARBLE/plotting.py @@ -11,7 +11,6 @@ from matplotlib.colors import LinearSegmentedColormap from matplotlib.patches import FancyArrowPatch from mpl_toolkits.mplot3d import proj3d -from mpl_toolkits.mplot3d.art3d import Line3DCollection from scipy.spatial import Voronoi from scipy.spatial import voronoi_plot_2d from torch_geometric.utils.convert import to_networkx diff --git a/MARBLE/postprocessing.py b/MARBLE/postprocessing.py index f8d9e449..16f4384e 100644 --- a/MARBLE/postprocessing.py +++ b/MARBLE/postprocessing.py @@ -5,6 +5,7 @@ def cluster(data, cluster_typ="kmeans", n_clusters=15, seed=0): + """Cluster data.""" clusters = g.cluster(data.emb, cluster_typ, n_clusters, seed) clusters = g.relabel_by_proximity(clusters) diff --git a/MARBLE/preprocessing.py b/MARBLE/preprocessing.py index 4caf7662..5ea3fa86 100644 --- a/MARBLE/preprocessing.py +++ b/MARBLE/preprocessing.py @@ -47,7 +47,7 @@ def construct_dataset( if labels is None: labels = [torch.arange(len(p)) for p in utils.to_list(pos)] else: - labels = [torch.tensor(l).float() for l in utils.to_list(labels)] + labels = [torch.tensor(label).float() for label in utils.to_list(labels)] if mask is None: mask = [torch.zeros(len(p), dtype=torch.bool) for p in utils.to_list(pos)] diff --git a/tests/test_vector_diffusion.py b/tests/test_vector_diffusion.py index df20ee66..e71ed62a 100644 --- a/tests/test_vector_diffusion.py +++ b/tests/test_vector_diffusion.py @@ -141,7 +141,9 @@ def test_diffusion_sphere(plot=False): y = f2(x) # construct PyG data object - data = construct_dataset(x, y, graph_type="radius", k=k, n_geodesic_nb=10, var_explained=0.9) + data = construct_dataset( + x, y, graph_type="radius", k=k, frac_geodesic_nb=1.5, var_explained=0.9 + ) L = geometry.compute_laplacian(data) From e7d91c68b577ea20e8784789c21ea71adf314341 Mon Sep 17 00:00:00 2001 From: arnaudon Date: Thu, 30 Nov 2023 10:55:53 +0100 Subject: [PATCH 3/6] test popple --- .github/workflows/run-tox.yml | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/.github/workflows/run-tox.yml b/.github/workflows/run-tox.yml index 2ea484fb..445767e3 100644 --- a/.github/workflows/run-tox.yml +++ b/.github/workflows/run-tox.yml @@ -24,7 +24,8 @@ jobs: python-version: ${{ matrix.python-version }} - name: Install dependencies run: | - sudo apt-get install -y poppler-utils imagemagick pandoc + sudo apt-get update + sudo apt-get install -y poppler-utils imagemagick pandoc --fix-missing python -m pip install --upgrade pip setuptools pip install tox-gh-actions pandoc - name: Run tox From 5b2a026beab02583211615ccbb1b0d6569d25e51 Mon Sep 17 00:00:00 2001 From: arnaudon Date: Thu, 30 Nov 2023 11:44:25 +0100 Subject: [PATCH 4/6] lint --- MARBLE/main.py | 9 ++++++--- MARBLE/plotting.py | 3 ++- MARBLE/preprocessing.py | 4 ++-- 3 files changed, 10 insertions(+), 6 deletions(-) diff --git a/MARBLE/main.py b/MARBLE/main.py index c3d38705..98bc149e 100644 --- a/MARBLE/main.py +++ b/MARBLE/main.py @@ -4,16 +4,19 @@ import warnings from datetime import datetime from pathlib import Path -import yaml -from tqdm import tqdm import torch import torch.nn.functional as F import torch.optim as opt +import yaml from torch import nn from torch_geometric.nn import MLP +from tqdm import tqdm -from MARBLE import dataloader, geometry, layers, utils +from MARBLE import dataloader +from MARBLE import geometry +from MARBLE import layers +from MARBLE import utils class net(nn.Module): diff --git a/MARBLE/plotting.py b/MARBLE/plotting.py index 099fe09b..0b1f464d 100644 --- a/MARBLE/plotting.py +++ b/MARBLE/plotting.py @@ -1,12 +1,13 @@ """Plotting module.""" import os from pathlib import Path -import torch + import matplotlib import matplotlib.pyplot as plt import networkx as nx import numpy as np import seaborn as sns +import torch from matplotlib import gridspec from matplotlib.colors import LinearSegmentedColormap from matplotlib.patches import FancyArrowPatch diff --git a/MARBLE/preprocessing.py b/MARBLE/preprocessing.py index 5ea3fa86..1fbd5060 100644 --- a/MARBLE/preprocessing.py +++ b/MARBLE/preprocessing.py @@ -125,7 +125,7 @@ def _compute_geometric_objects( gauges (nxdxd): local gauges at all points par (dict): updated dictionary of parameters local_gauges: whether to use local gauges - + """ n, dim_emb = data.pos.shape dim_signal = data.x.shape[1] @@ -184,4 +184,4 @@ def _compute_geometric_objects( ] data.L, data.Lc, data.gauges, data.local_gauges = L, Lc, gauges, local_gauges - return data \ No newline at end of file + return data From 1891ba048e8ab4d37c18055e31f48c16c534cde8 Mon Sep 17 00:00:00 2001 From: arnaudon Date: Thu, 30 Nov 2023 11:45:49 +0100 Subject: [PATCH 5/6] fix --- MARBLE/preprocessing.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/MARBLE/preprocessing.py b/MARBLE/preprocessing.py index 1fbd5060..836106c4 100644 --- a/MARBLE/preprocessing.py +++ b/MARBLE/preprocessing.py @@ -113,7 +113,8 @@ def _compute_geometric_objects( Args: data: pytorch geometric data object - frac_geodesic_nb: fraction of geodesic neighbours relative to neighbours to fit the tangent spaces to + frac_geodesic_nb: fraction of geodesic neighbours relative to neighbours + to fit the tangent spaces to var_explained: fraction of variance explained by the local gauges local_gauges: whether to use local or global gauges From 3dedb01f9f4299f19a353152c6e05dd1dbcebd44 Mon Sep 17 00:00:00 2001 From: arnaudon Date: Thu, 30 Nov 2023 12:56:43 +0100 Subject: [PATCH 6/6] fix lint --- MARBLE/geometry.py | 6 +++--- MARBLE/plotting.py | 3 --- 2 files changed, 3 insertions(+), 6 deletions(-) diff --git a/MARBLE/geometry.py b/MARBLE/geometry.py index e3d74bb0..f30b6d6d 100644 --- a/MARBLE/geometry.py +++ b/MARBLE/geometry.py @@ -465,7 +465,7 @@ def compute_laplacian(data, normalization="rw"): def compute_connection_laplacian(data, R, normalization="rw"): - """Connection Laplacian + r"""Connection Laplacian Args: data: Pytorch geometric data object. @@ -520,7 +520,7 @@ def compute_connection_laplacian(data, R, normalization="rw"): def compute_gauges(data, dim_man=None, n_geodesic_nb=10, n_workers=1): - """Orthonormal gauges for the tangent space at each node, and connection + r"""Orthonormal gauges for the tangent space at each node, and connection matrices between each pair of adjacent nodes. R is a block matrix, where the row index is the gauge we want to align to, @@ -573,7 +573,7 @@ def _compute_gauges(inputs, i): def compute_connections(data, gauges, n_workers=1): - """Find smallest rotations R between gauges pairs. It is assumed that the first + r"""Find smallest rotations R between gauges pairs. It is assumed that the first row of edge_index is what we want to align to, i.e., gauges(i) = gauges(j)@R[i,j].T diff --git a/MARBLE/plotting.py b/MARBLE/plotting.py index 0b1f464d..e637f31f 100644 --- a/MARBLE/plotting.py +++ b/MARBLE/plotting.py @@ -1,7 +1,4 @@ """Plotting module.""" -import os -from pathlib import Path - import matplotlib import matplotlib.pyplot as plt import networkx as nx