Skip to content

Commit

Permalink
Update preprocessing.py
Browse files Browse the repository at this point in the history
  • Loading branch information
agosztolai authored Dec 11, 2023
1 parent 017e6cf commit b725056
Showing 1 changed file with 4 additions and 1 deletion.
5 changes: 4 additions & 1 deletion MARBLE/preprocessing.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ def construct_dataset(
number_of_resamples=1,
var_explained=0.9,
local_gauges=False,
start_idx=None,
):
"""Construct PyG dataset from node positions and features.
Expand All @@ -40,6 +41,7 @@ def construct_dataset(
var_explained: fraction of variance explained by the local gauges
local_gauges: is True, it will try to compute local gauges if it can (signal dim is > 2,
embedding dimension is > 2 or dim embedding is not dim of manifold)
start_idx: Specify for reproducibility in the furthest point sampling. The default is None, which means a random starting vertex.
"""

anchor = [torch.tensor(p).float() for p in utils.to_list(anchor)]
Expand All @@ -63,7 +65,8 @@ def construct_dataset(
for i, (a, v, l, m) in enumerate(zip(anchor, vector, label, mask)):
for _ in range(number_of_resamples):
# even sampling of points
start_idx = torch.randint(low=0, high=len(a), size=(1,))
start_idx is None:
start_idx = torch.randint(low=0, high=len(a), size=(1,))
sample_ind, _ = g.furthest_point_sampling(a, spacing=spacing, start_idx=start_idx)
sample_ind, _ = torch.sort(sample_ind) #this will make postprocessing easier
a_, v_, l_, m_ = a[sample_ind], v[sample_ind], l[sample_ind], m[sample_ind]
Expand Down

0 comments on commit b725056

Please sign in to comment.