Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

some small changes #23

Merged
merged 1 commit into from
Dec 4, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 4 additions & 3 deletions MARBLE/geometry.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@ def furthest_point_sampling(x, N=None, spacing=0.0, start_idx=0):
break

assert len(perm) == len(np.unique(perm)), "Returned duplicated points"

return perm, lambdas


Expand All @@ -77,7 +78,7 @@ def cluster(x, cluster_typ="meanshift", n_clusters=15, seed=0):
"""
clusters = {}
if cluster_typ == "kmeans":
kmeans = KMeans(n_clusters=n_clusters, random_state=seed).fit(x)
kmeans = KMeans(n_clusters=n_clusters, random_state=seed, n_init='auto').fit(x)
clusters["n_clusters"] = n_clusters
clusters["labels"] = kmeans.labels_
clusters["centroids"] = kmeans.cluster_centers_
Expand Down Expand Up @@ -157,7 +158,7 @@ def relabel_by_proximity(clusters):
"""Update clusters labels such that nearby clusters in the embedding get similar labels.

Args:
clusters: sklearn object containing 'centroids', 'n_clusters', 'labels'
clusters: sklearn object containing 'centroids', 'n_clusters', 'labels' as attributes

Returns:
clusters: sklearn object with updated labels
Expand Down Expand Up @@ -186,7 +187,7 @@ def compute_distribution_distances(clusters=None, data=None, slices=None):
"""Compute the distance between clustered distributions across datasets.

Args:
clusters: sklearn object containing 'centroids', 'slices', 'labels'
clusters: sklearn object containing 'centroids', 'slices', 'labels' as attributes

Returns:
dist: distance matrix
Expand Down
19 changes: 6 additions & 13 deletions MARBLE/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,8 +43,8 @@ class net(nn.Module):
bias: learn bias parameters in MLP (default=True)
vec_norm: normalise features at each derivative order to unit length (default=False)
emb_norm: normalise MLP output to unit length (default=False)
batch_norm: batch normalisation (default=False)
seed: seed for reproducibility (default=0)
batch_norm: batch normalisation (default=True)
seed: seed for reproducibility
"""

def __init__(self, data, loadpath=None, params=None, verbose=True):
Expand Down Expand Up @@ -123,11 +123,6 @@ def check_parameters(self, data):
if self.params["diffusion"]:
assert hasattr(data, "L"), "No Laplacian found. Compute it in preprocessing()!"

if data.local_gauges:
assert self.params[
"inner_product_features"
], "Local gauges detected, so >>inner_product_features<< must be True"

pars = [
"batch_size",
"epochs",
Expand All @@ -137,7 +132,6 @@ def check_parameters(self, data):
"inner_product_features",
"dim_signal",
"dim_emb",
"dim_man",
"frac_sampled_nb",
"dropout",
"diffusion",
Expand All @@ -148,14 +142,12 @@ def check_parameters(self, data):
"vec_norm",
"emb_norm",
"seed",
"n_sampled_nb",
"include_positions",
"include_self",
"processes",
]

for p in self.params.keys():
assert p in pars, f"Unknown specified parameter {p}!"
for p in pars:
assert p in list(self.params.keys()), f"Parameter {p} is not specified!"

def reset_parameters(self):
"""reset parmaeters."""
Expand Down Expand Up @@ -205,6 +197,7 @@ def setup_layers(self):
channel_list=channel_list,
dropout=self.params["dropout"],
bias=self.params["bias"],
norm=self.params['batch_norm']
)

def forward(self, data, n_id, adjs=None):
Expand Down
20 changes: 17 additions & 3 deletions MARBLE/plotting.py
Original file line number Diff line number Diff line change
Expand Up @@ -532,10 +532,15 @@ def trajectories(
Args:
X (np array): Positions
V (np array): Velocities
style (string): Plotting style. The default is 'o'
ax (matplotlib axes object): If specificed, it will plot on existing axes. The default is None
style (string): Plotting style. 'o' for scatter plot or '-' for line plot
node_feature: Color lines. The default is None
lw (int): Line width
ms (int): Marker size
scale (float): Scaling of arrows
arrow_spacing (int): How many timesteps apart are the arrows spaced.
axes_visible (bool): Whether to display axes
alpha (float): transparancy of the markers

Returns:
matplotlib axes object.
Expand Down Expand Up @@ -563,7 +568,14 @@ def trajectories(
alpha=alpha,
)
else:
ax.plot(X[:, 0], X[:, 1], c=c, linewidth=lw, markersize=ms, alpha=alpha)
ax.plot(
X[:, 0],
X[:, 1],
c=c,
linewidth=lw,
markersize=ms,
alpha=alpha
)
if ">" in style:
skip = (slice(None, None, arrow_spacing), slice(None))
X, V = X[skip], V[skip]
Expand All @@ -573,7 +585,7 @@ def trajectories(
if "o" in style:
ax.scatter(X[:, 0], X[:, 1], X[:, 2], c=c, s=ms, alpha=alpha)
if "-" in style:
if isinstance(c, (list, tuple)):
if isinstance(c, (np.ndarray, list, tuple)):
for i in range(len(X) - 2):
ax.plot(
X[i : i + 2, 0],
Expand All @@ -600,6 +612,8 @@ def trajectories(
skip = (slice(None, None, arrow_spacing), slice(None))
X, V = X[skip], V[skip]
plot_arrows(X, V, ax, c, width=lw, scale=scale)
else:
raise Exception('Data dimension is: {}. It needs to be 2 or 3 to allow plotting.'.format(dim))

set_axes(ax, axes_visible=axes_visible)

Expand Down
6 changes: 3 additions & 3 deletions MARBLE/preprocessing.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,13 +138,13 @@ def _compute_geometric_objects(
# disable vector computations if 1) signal is scalar or 2) embedding dimension
# is <= 2. In case 2), either M=R^2 (manifold is whole space) or case 1).
if dim_signal == 1:
print("Signal dimension is 1, so manifold computations are disabled!")
print("\nSignal dimension is 1, so manifold computations are disabled!")
local_gauges = False
if dim_emb <= 2:
print("Embedding dimension <= 2, so manifold computations are disabled!")
print("\nEmbedding dimension <= 2, so manifold computations are disabled!")
local_gauges = False
if dim_emb != dim_signal:
print("Embedding dimension /= signal dimension, so manifold computations are disabled!")
print("\nEmbedding dimension /= signal dimension, so manifold computations are disabled!")

if local_gauges:
try:
Expand Down
Loading