diff --git a/MARBLE/geometry.py b/MARBLE/geometry.py index 662bc66a..3c931256 100644 --- a/MARBLE/geometry.py +++ b/MARBLE/geometry.py @@ -60,6 +60,7 @@ def furthest_point_sampling(x, N=None, spacing=0.0, start_idx=0): break assert len(perm) == len(np.unique(perm)), "Returned duplicated points" + return perm, lambdas @@ -77,7 +78,7 @@ def cluster(x, cluster_typ="meanshift", n_clusters=15, seed=0): """ clusters = {} if cluster_typ == "kmeans": - kmeans = KMeans(n_clusters=n_clusters, random_state=seed).fit(x) + kmeans = KMeans(n_clusters=n_clusters, random_state=seed, n_init='auto').fit(x) clusters["n_clusters"] = n_clusters clusters["labels"] = kmeans.labels_ clusters["centroids"] = kmeans.cluster_centers_ @@ -157,7 +158,7 @@ def relabel_by_proximity(clusters): """Update clusters labels such that nearby clusters in the embedding get similar labels. Args: - clusters: sklearn object containing 'centroids', 'n_clusters', 'labels' + clusters: sklearn object containing 'centroids', 'n_clusters', 'labels' as attributes Returns: clusters: sklearn object with updated labels @@ -186,7 +187,7 @@ def compute_distribution_distances(clusters=None, data=None, slices=None): """Compute the distance between clustered distributions across datasets. 
Args: - clusters: sklearn object containing 'centroids', 'slices', 'labels' + clusters: sklearn object containing 'centroids', 'slices', 'labels' as attributes Returns: dist: distance matrix diff --git a/MARBLE/main.py b/MARBLE/main.py index 3820dcef..a734b3eb 100644 --- a/MARBLE/main.py +++ b/MARBLE/main.py @@ -43,8 +43,8 @@ class net(nn.Module): bias: learn bias parameters in MLP (default=True) vec_norm: normalise features at each derivative order to unit length (default=False) emb_norm: normalise MLP output to unit length (default=False) - batch_norm: batch normalisation (default=False) - seed: seed for reproducibility (default=0) + batch_norm: batch normalisation (default=True) + seed: seed for reproducibility """ def __init__(self, data, loadpath=None, params=None, verbose=True): @@ -123,11 +123,6 @@ def check_parameters(self, data): if self.params["diffusion"]: assert hasattr(data, "L"), "No Laplacian found. Compute it in preprocessing()!" - if data.local_gauges: - assert self.params[ - "inner_product_features" - ], "Local gauges detected, so >>inner_product_features<< must be True" - pars = [ "batch_size", "epochs", @@ -137,7 +132,6 @@ def check_parameters(self, data): "inner_product_features", "dim_signal", "dim_emb", - "dim_man", "frac_sampled_nb", "dropout", "diffusion", @@ -148,14 +142,12 @@ def check_parameters(self, data): "vec_norm", "emb_norm", "seed", - "n_sampled_nb", "include_positions", "include_self", - "processes", ] - - for p in self.params.keys(): - assert p in pars, f"Unknown specified parameter {p}!" + + for p in pars: + assert p in list(self.params.keys()), f"Parameter {p} is not specified!" 
def reset_parameters(self): """reset parmaeters.""" @@ -205,6 +197,7 @@ def setup_layers(self): channel_list=channel_list, dropout=self.params["dropout"], bias=self.params["bias"], + norm=self.params['batch_norm'] ) def forward(self, data, n_id, adjs=None): diff --git a/MARBLE/plotting.py b/MARBLE/plotting.py index 39894efd..a26fc2a3 100644 --- a/MARBLE/plotting.py +++ b/MARBLE/plotting.py @@ -532,10 +532,15 @@ def trajectories( Args: X (np array): Positions V (np array): Velocities - style (string): Plotting style. The default is 'o' + ax (matplotlib axes object): If specified, it will plot on existing axes. The default is None + style (string): Plotting style. 'o' for scatter plot or '-' for line plot node_feature: Color lines. The default is None lw (int): Line width ms (int): Marker size + scale (float): Scaling of arrows + arrow_spacing (int): How many timesteps apart are the arrows spaced. + axes_visible (bool): Whether to display axes + alpha (float): transparency of the markers Returns: matplotlib axes object. @@ -563,7 +568,14 @@ def trajectories( alpha=alpha, ) else: - ax.plot(X[:, 0], X[:, 1], c=c, linewidth=lw, markersize=ms, alpha=alpha) + ax.plot( + X[:, 0], + X[:, 1], + c=c, + linewidth=lw, + markersize=ms, + alpha=alpha + ) if ">" in style: skip = (slice(None, None, arrow_spacing), slice(None)) X, V = X[skip], V[skip] @@ -573,7 +585,7 @@ def trajectories( if "o" in style: ax.scatter(X[:, 0], X[:, 1], X[:, 2], c=c, s=ms, alpha=alpha) if "-" in style: - if isinstance(c, (list, tuple)): + if isinstance(c, (np.ndarray, list, tuple)): for i in range(len(X) - 2): ax.plot( X[i : i + 2, 0], @@ -600,6 +612,8 @@ def trajectories( skip = (slice(None, None, arrow_spacing), slice(None)) X, V = X[skip], V[skip] plot_arrows(X, V, ax, c, width=lw, scale=scale) + else: + raise Exception('Data dimension is: {}. 
It needs to be 2 or 3 to allow plotting.'.format(dim)) set_axes(ax, axes_visible=axes_visible) diff --git a/MARBLE/preprocessing.py b/MARBLE/preprocessing.py index d6001ce7..0d072677 100644 --- a/MARBLE/preprocessing.py +++ b/MARBLE/preprocessing.py @@ -138,13 +138,13 @@ def _compute_geometric_objects( # disable vector computations if 1) signal is scalar or 2) embedding dimension # is <= 2. In case 2), either M=R^2 (manifold is whole space) or case 1). if dim_signal == 1: - print("Signal dimension is 1, so manifold computations are disabled!") + print("\nSignal dimension is 1, so manifold computations are disabled!") local_gauges = False if dim_emb <= 2: - print("Embedding dimension <= 2, so manifold computations are disabled!") + print("\nEmbedding dimension <= 2, so manifold computations are disabled!") local_gauges = False if dim_emb != dim_signal: - print("Embedding dimension /= signal dimension, so manifold computations are disabled!") + print("\nEmbedding dimension /= signal dimension, so manifold computations are disabled!") if local_gauges: try: