From 56a25404d15aa1ce384d6205396aac9b84177e39 Mon Sep 17 00:00:00 2001 From: agosztolai Date: Thu, 23 May 2024 14:18:37 +0200 Subject: [PATCH] small bug with using devices --- MARBLE/geometry.py | 2 +- MARBLE/main.py | 5 +++-- MARBLE/smoothing.py | 2 +- 3 files changed, 5 insertions(+), 4 deletions(-) diff --git a/MARBLE/geometry.py b/MARBLE/geometry.py index dfbc3352..dc07b870 100644 --- a/MARBLE/geometry.py +++ b/MARBLE/geometry.py @@ -351,7 +351,7 @@ def manifold_dimension(Sigma, frac_explained=0.9): var_exp = Sigma.mean(0) - Sigma.std(0) dim_man = torch.where(var_exp >= frac_explained)[0][0] + 1 - print("\nFraction of variance explained: ", var_exp) + print("\nFraction of variance explained: ", var_exp.tolist()) return int(dim_man) diff --git a/MARBLE/main.py b/MARBLE/main.py index e10658d2..4f58647d 100644 --- a/MARBLE/main.py +++ b/MARBLE/main.py @@ -319,7 +319,6 @@ def batch_loss(self, data, loader, train=False, verbose=False, optimizer=None): for batch in tqdm(loader, disable=not verbose): _, n_id, adjs = batch adjs = [adj.to(data.x.device) for adj in utils.to_list(adjs)] - emb, mask = self.forward(data, n_id, adjs) loss = self.loss(emb, mask) cum_loss += float(loss) @@ -379,7 +378,9 @@ def fit(self, data, outdir=None, verbose=False): train_loss, optimizer = self.batch_loss( data, train_loader, train=True, verbose=verbose, optimizer=optimizer ) - val_loss, _ = self.batch_loss(data, val_loader, verbose=verbose) + val_loss, _ = self.batch_loss( + data, val_loader, verbose=verbose + ) scheduler.step(train_loss) print( diff --git a/MARBLE/smoothing.py b/MARBLE/smoothing.py index 16ab1d9b..072359c4 100644 --- a/MARBLE/smoothing.py +++ b/MARBLE/smoothing.py @@ -57,7 +57,7 @@ def vector_diffusion(x, t, Lc, L=None, method="spectral", normalise=True): assert L is not None, "Need Laplacian for normalised diffusion!" x_abs = x.norm(dim=-1, p=2, keepdim=True) out_abs = scalar_diffusion(x_abs, t, method, L) - ind = scalar_diffusion(torch.ones(x.shape[0], 1), t, method, L) + ind = scalar_diffusion(torch.ones(x.shape[0], 1).to(x.device), t, method, L) out = out * out_abs / (ind * out.norm(dim=-1, p=2, keepdim=True)) return out