Skip to content

Commit

Permalink
code formatting
Browse files Browse the repository at this point in the history
  • Loading branch information
agosztolai committed Dec 12, 2023
1 parent 6560708 commit 4a98923
Show file tree
Hide file tree
Showing 10 changed files with 54 additions and 54 deletions.
26 changes: 16 additions & 10 deletions MARBLE/geometry.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ def furthest_point_sampling(x, N=None, spacing=0.0, start_idx=0):
break

assert len(perm) == len(np.unique(perm)), "Returned duplicated points"

return perm, lambdas


Expand All @@ -78,7 +78,7 @@ def cluster(x, cluster_typ="meanshift", n_clusters=15, seed=0):
"""
clusters = {}
if cluster_typ == "kmeans":
kmeans = KMeans(n_clusters=n_clusters, random_state=seed, n_init='auto').fit(x)
kmeans = KMeans(n_clusters=n_clusters, random_state=seed, n_init="auto").fit(x)
clusters["n_clusters"] = n_clusters
clusters["labels"] = kmeans.labels_
clusters["centroids"] = kmeans.cluster_centers_
Expand Down Expand Up @@ -425,7 +425,7 @@ def compute_connection_laplacian(data, R, normalization="rw"):
d = R.size()[0] // n

# unnormalised (combinatorial) laplacian, to be normalised later
L = compute_laplacian(data, normalization=None)#.to_sparse()
L = compute_laplacian(data, normalization=None) # .to_sparse()

# rearrange into block form (kron(L, ones(d,d)))
edge_index = utils.expand_edge_index(L.indices(), dim=d)
Expand Down Expand Up @@ -478,8 +478,11 @@ def compute_gauges(data, dim_man=None, n_geodesic_nb=10, processes=1):

inputs = [X, A, dim_man, n_geodesic_nb]
out = utils.parallel_proc(
_compute_gauges, range(n), inputs, processes=processes,
desc="\n---- Computing tangent spaces..."
_compute_gauges,
range(n),
inputs,
processes=processes,
desc="\n---- Computing tangent spaces...",
)

gauges, Sigma = zip(*out)
Expand Down Expand Up @@ -527,8 +530,11 @@ def compute_connections(data, gauges, processes=1):

inputs = [gauges, A, dim_man]
out = utils.parallel_proc(
_compute_connections, range(n), inputs, processes=processes,
desc="\n---- Computing connections..."
_compute_connections,
range(n),
inputs,
processes=processes,
desc="\n---- Computing connections...",
)

return utils.to_block_diag(out)
Expand Down Expand Up @@ -560,13 +566,13 @@ def compute_eigendecomposition(A, k=None, eps=1e-8):
"""
if A is None:
return None

if k is None:
A = A.to_dense()
else:
else:
indices, values, size = A.indices(), A.values(), A.size()
A = sp.coo_array((values, (indices[0], indices[1])), shape=size)

failcount = 0
while True:
try:
Expand Down
4 changes: 1 addition & 3 deletions MARBLE/layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,7 @@ def __init__(self, tau0=0.0):
def forward(self, x, L, Lc=None, method="spectral"):
"""Forward."""
if method == "spectral":
assert (
len(L) == 2
), "L must be a matrix or a pair of eigenvalues and eigenvectors"
assert len(L) == 2, "L must be a matrix or a pair of eigenvalues and eigenvectors"

# making sure diffusion times are positive
with torch.no_grad():
Expand Down
12 changes: 6 additions & 6 deletions MARBLE/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,7 +145,7 @@ def check_parameters(self, data):
"include_positions",
"include_self",
]

for p in pars:
assert p in list(self.params.keys()), f"Parameter {p} is not specified!"

Expand Down Expand Up @@ -197,7 +197,7 @@ def setup_layers(self):
channel_list=channel_list,
dropout=self.params["dropout"],
bias=self.params["bias"],
norm=self.params['batch_norm']
norm=self.params["batch_norm"],
)

def forward(self, data, n_id, adjs=None):
Expand All @@ -219,11 +219,11 @@ def forward(self, data, n_id, adjs=None):
x = geometry.global_to_local_frame(x, data.gauges, reverse=True)
else:
x = self.diffusion(x, data.L, method="spectral")

# local gauges
if self.params["inner_product_features"]:
x = geometry.global_to_local_frame(x, data.gauges)

# restrict to current batch
x = x[n_id]
mask = mask[n_id]
Expand Down Expand Up @@ -262,7 +262,7 @@ def forward(self, data, n_id, adjs=None):

if self.params["include_positions"]:
out = torch.hstack([data.pos[n_id[: last_size[1]]], out])

emb = self.enc(out)

if self.params["emb_norm"]: # spherical output
Expand Down Expand Up @@ -415,7 +415,7 @@ def load_model(self, loadpath):
self._epoch = checkpoint["epoch"]
self.load_state_dict(checkpoint["model_state_dict"])
self.optimizer_state_dict = checkpoint["optimizer_state_dict"]
if hasattr(self, 'losses'):
if hasattr(self, "losses"):
self.losses = checkpoint["losses"]

def save_model(self, optimizer, losses, outdir=None, best=False, timestamp=""):
Expand Down
18 changes: 6 additions & 12 deletions MARBLE/plotting.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
"""Plotting module."""
import torch
import matplotlib
import matplotlib.pyplot as plt
import networkx as nx
Expand Down Expand Up @@ -538,7 +537,7 @@ def trajectories(
lw (int): Line width
ms (int): Marker size
scale (float): Scaling of arrows
arrow_spacing (int): How many timesteps apart are the arrows spaced.
arrow_spacing (int): How many timesteps apart are the arrows spaced.
axes_visible (bool): Whether to display axes
alpha (float): transparancy of the markers
Expand Down Expand Up @@ -568,14 +567,7 @@ def trajectories(
alpha=alpha,
)
else:
ax.plot(
X[:, 0],
X[:, 1],
c=c,
linewidth=lw,
markersize=ms,
alpha=alpha
)
ax.plot(X[:, 0], X[:, 1], c=c, linewidth=lw, markersize=ms, alpha=alpha)
if ">" in style:
skip = (slice(None, None, arrow_spacing), slice(None))
X, V = X[skip], V[skip]
Expand Down Expand Up @@ -613,7 +605,9 @@ def trajectories(
X, V = X[skip], V[skip]
plot_arrows(X, V, ax, c, width=lw, scale=scale)
else:
raise Exception('Data dimension is: {}. It needs to be 2 or 3 to allow plotting.'.format(dim))
raise Exception(
"Data dimension is: {}. It needs to be 2 or 3 to allow plotting.".format(dim)
)

set_axes(ax, axes_visible=axes_visible)

Expand Down Expand Up @@ -696,7 +690,7 @@ def create_axis(*args, fig=None):
elif dim == 3:
ax = fig.add_subplot(*args, projection="3d")
else:
raise Exception('Data dimension is {}. We can only plot 2D or 3D data.'.format(dim))
raise Exception("Data dimension is {}. We can only plot 2D or 3D data.".format(dim))

return fig, ax

Expand Down
2 changes: 1 addition & 1 deletion MARBLE/postprocessing.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,4 +79,4 @@ def embed_in_2D(data, embed_typ="umap", manifold=None, seed=0):
else:
data.emb_2D, data.manifold = g.embed(emb, embed_typ=embed_typ, manifold=manifold, seed=seed)

return data
return data
11 changes: 6 additions & 5 deletions MARBLE/preprocessing.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""Preprocessing module."""
import torch
from torch_geometric.data import Batch, Data
from torch_geometric.data import Batch
from torch_geometric.data import Data
from torch_geometric.transforms import RandomNodeSplit

from MARBLE import geometry as g
Expand Down Expand Up @@ -47,7 +48,7 @@ def construct_dataset(
anchor = [torch.tensor(p).float() for p in utils.to_list(anchor)]
vector = [torch.tensor(x).float() for x in utils.to_list(vector)]
num_node_features = vector[0].shape[1]

if label is None:
label = [torch.arange(len(p)) for p in utils.to_list(anchor)]
else:
Expand All @@ -70,12 +71,12 @@ def construct_dataset(
else:
start_idx = 0
sample_ind, _ = g.furthest_point_sampling(a, spacing=spacing, start_idx=start_idx)
sample_ind, _ = torch.sort(sample_ind) #this will make postprocessing easier
sample_ind, _ = torch.sort(sample_ind) # this will make postprocessing easier
a_, v_, l_, m_ = a[sample_ind], v[sample_ind], l[sample_ind], m[sample_ind]

# fit graph to point cloud
edge_index, edge_weight = g.fit_graph(a_, graph_type=graph_type, par=k, delta=delta)

# define data object
data_ = Data(
pos=a_,
Expand Down Expand Up @@ -104,7 +105,7 @@ def construct_dataset(
return _compute_geometric_objects(
batch,
local_gauges=local_gauges,
n_geodesic_nb=k*frac_geodesic_nb,
n_geodesic_nb=k * frac_geodesic_nb,
var_explained=var_explained,
)

Expand Down
15 changes: 8 additions & 7 deletions MARBLE/smoothing.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

import torch


def scalar_diffusion(x, t, method="matrix_exp", par=None):
"""Scalar diffusion."""
if len(x.shape) == 1:
Expand All @@ -12,7 +13,7 @@ def scalar_diffusion(x, t, method="matrix_exp", par=None):
if par.is_sparse:
par = par.to_dense()
return torch.matrix_exp(-t * par.to_dense()).mm(x)

if method == "spectral":
assert (
isinstance(par, (list, tuple)) and len(par) == 2
Expand All @@ -30,8 +31,8 @@ def scalar_diffusion(x, t, method="matrix_exp", par=None):
# Transform back to per-vertex
return evecs.mm(x_diffuse_spec)

raise NotImplementedError
raise NotImplementedError


def vector_diffusion(x, t, Lc, L=None, method="spectral", normalise=True):
"""Vector diffusion."""
Expand All @@ -54,10 +55,10 @@ def vector_diffusion(x, t, Lc, L=None, method="spectral", normalise=True):
out = out.view(x.shape)

if normalise:
assert L is not None, 'Need Laplacian for normalised diffusion!'
assert L is not None, "Need Laplacian for normalised diffusion!"
x_abs = x.norm(dim=-1, p=2, keepdim=True)
out_abs = scalar_diffusion(x_abs, t, method, L)
ind = scalar_diffusion(torch.ones(x.shape[0],1), t, method, L)
out = out*out_abs/(ind*out.norm(dim=-1, p=2, keepdim=True))
ind = scalar_diffusion(torch.ones(x.shape[0], 1), t, method, L)
out = out * out_abs / (ind * out.norm(dim=-1, p=2, keepdim=True))

return out
return out
4 changes: 2 additions & 2 deletions MARBLE/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,15 +59,15 @@ def move_to_gpu(model, data, adjs=None):
data.mask = data.mask.to(device)

if hasattr(data, "L"):
if len(data.L)==2:
if len(data.L) == 2:
data.L = [_l.to(device) for _l in data.L]
else:
data.L = data.L.to(device)
else:
data.L = None

if hasattr(data, "Lc"):
if len(data.Lc)==2:
if len(data.Lc) == 2:
data.Lc = [_l.to(device) for _l in data.Lc]
else:
data.Lc = data.Lc.to(device)
Expand Down
6 changes: 3 additions & 3 deletions tests/test_grad_kernel.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ def test_gauges(plot=False):
x = np.vstack([xv.flatten(), yv.flatten()]).T

y = f1(x, alpha)
y = torch.tensor(y)
# y = torch.tensor(y)

data = construct_dataset(x, y, graph_type="cknn", k=k)
gauges = data.gauges
Expand All @@ -45,7 +45,7 @@ def test_gauges(plot=False):
K = [utils.to_SparseTensor(_K.coalesce().indices(), value=_K.coalesce().values()) for _K in K]

assert_array_almost_equal(
K[0].to_scipy().toarray()[:5, :5],
K[0].to_dense()[:5, :5],
np.array(
[
[-1.0, 0.25, 0.5, 0.0, 0.0],
Expand All @@ -59,7 +59,7 @@ def test_gauges(plot=False):
)

grad = AnisoConv()
der = grad(y, K)
der = grad(torch.tensor(y), K)
assert_array_almost_equal(
der.numpy()[:10],
np.array(
Expand Down
10 changes: 5 additions & 5 deletions tests/test_vector_diffusion.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,11 +116,11 @@ def test_diffusion(plot=False):
data.x.detach().numpy()[:5],
np.array(
[
[0.4629867, 0.124888],
[-0.03240441, 0.46744648],
[-0.29019982, 0.40529823],
[0.53165627, -0.00788487],
[0.288114, 0.34210885],
[0.8945822, 0.2413084],
[-0.06356261, 0.9169159],
[-0.5462601, 0.7629167],
[0.9424986, -0.01397797],
[0.57801795, 0.68634313],
]
),
decimal=5,
Expand Down

0 comments on commit 4a98923

Please sign in to comment.