Skip to content

Commit

Permalink
code formatting
Browse files Browse the repository at this point in the history
  • Loading branch information
agosztolai committed Dec 12, 2023
1 parent 4a98923 commit 86455c8
Show file tree
Hide file tree
Showing 3 changed files with 9 additions and 7 deletions.
Binary file added .coverage-py
Binary file not shown.
Binary file added .coverage-py.SV-87M-007.73249.019415
Binary file not shown.
16 changes: 9 additions & 7 deletions MARBLE/preprocessing.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,28 +34,30 @@ def construct_dataset(
graph_type: type of nearest-neighbours graph: cknn (default), knn or radius
k: number of nearest-neighbours to construct the graph
delta: argument for cknn graph construction to decide the radius for each points.
n_eigenvalues: number of eigenvalue/eigenvector pairs to compute (None means all, but this can be slow)
n_eigenvalues: number of eigenvalue/eigenvector pairs to compute (None means all,
but this can be slow)
frac_geodesic_nb: number of geodesic neighbours to fit the gauges to
to map to tangent space k*frac_geodesic_nb
stop_crit: stopping criterion for furthest point sampling
number_of_resamples: number of furthest point sampling runs to prevent bias (experimental)
var_explained: fraction of variance explained by the local gauges
local_gauges: is True, it will try to compute local gauges if it can (signal dim is > 2,
embedding dimension is > 2 or dim embedding is not dim of manifold)
seed: Specify for reproducibility in the furthest point sampling. The default is None, which means a random starting vertex.
seed: Specify for reproducibility in the furthest point sampling.
The default is None, which means a random starting vertex.
"""

anchor = [torch.tensor(p).float() for p in utils.to_list(anchor)]
vector = [torch.tensor(x).float() for x in utils.to_list(vector)]
anchor = [torch.tensor(a).float() for a in utils.to_list(anchor)]
vector = [torch.tensor(v).float() for v in utils.to_list(vector)]
num_node_features = vector[0].shape[1]

if label is None:
label = [torch.arange(len(p)) for p in utils.to_list(anchor)]
label = [torch.arange(len(a)) for a in utils.to_list(anchor)]
else:
label = [torch.tensor(l).float() for l in utils.to_list(label)]
label = [torch.tensor(lab).float() for lab in utils.to_list(label)]

if mask is None:
mask = [torch.zeros(len(p), dtype=torch.bool) for p in utils.to_list(anchor)]
mask = [torch.zeros(len(a), dtype=torch.bool) for a in utils.to_list(anchor)]
else:
mask = [torch.tensor(m) for m in utils.to_list(mask)]

Expand Down

0 comments on commit 86455c8

Please sign in to comment.