Skip to content

Commit

Permalink
small bug with using devices
Browse files Browse the repository at this point in the history
  • Loading branch information
agosztolai committed May 23, 2024
1 parent 61949d9 commit 56a2540
Show file tree
Hide file tree
Showing 3 changed files with 5 additions and 4 deletions.
2 changes: 1 addition & 1 deletion MARBLE/geometry.py
Original file line number Diff line number Diff line change
Expand Up @@ -351,7 +351,7 @@ def manifold_dimension(Sigma, frac_explained=0.9):
var_exp = Sigma.mean(0) - Sigma.std(0)
dim_man = torch.where(var_exp >= frac_explained)[0][0] + 1

print("\nFraction of variance explained: ", var_exp)
print("\nFraction of variance explained: ", var_exp.tolist())

return int(dim_man)

Expand Down
5 changes: 3 additions & 2 deletions MARBLE/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -319,7 +319,6 @@ def batch_loss(self, data, loader, train=False, verbose=False, optimizer=None):
for batch in tqdm(loader, disable=not verbose):
_, n_id, adjs = batch
adjs = [adj.to(data.x.device) for adj in utils.to_list(adjs)]

emb, mask = self.forward(data, n_id, adjs)
loss = self.loss(emb, mask)
cum_loss += float(loss)
Expand Down Expand Up @@ -379,7 +378,9 @@ def fit(self, data, outdir=None, verbose=False):
train_loss, optimizer = self.batch_loss(
data, train_loader, train=True, verbose=verbose, optimizer=optimizer
)
val_loss, _ = self.batch_loss(data, val_loader, verbose=verbose)
val_loss, _ = self.batch_loss(
data, val_loader, verbose=verbose
)
scheduler.step(train_loss)

print(
Expand Down
2 changes: 1 addition & 1 deletion MARBLE/smoothing.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ def vector_diffusion(x, t, Lc, L=None, method="spectral", normalise=True):
assert L is not None, "Need Laplacian for normalised diffusion!"
x_abs = x.norm(dim=-1, p=2, keepdim=True)
out_abs = scalar_diffusion(x_abs, t, method, L)
ind = scalar_diffusion(torch.ones(x.shape[0], 1), t, method, L)
ind = scalar_diffusion(torch.ones(x.shape[0], 1).to(x.device), t, method, L)
out = out * out_abs / (ind * out.norm(dim=-1, p=2, keepdim=True))

return out

0 comments on commit 56a2540

Please sign in to comment.