diff --git a/MARBLE/preprocessing.py b/MARBLE/preprocessing.py
index 0d072677..e0d5af3d 100644
--- a/MARBLE/preprocessing.py
+++ b/MARBLE/preprocessing.py
@@ -21,6 +21,7 @@ def construct_dataset(
     number_of_resamples=1,
     var_explained=0.9,
     local_gauges=False,
+    start_idx=None,
 ):
     """Construct PyG dataset from node positions and features.
 
@@ -40,6 +41,8 @@ def construct_dataset(
         var_explained: fraction of variance explained by the local gauges
         local_gauges: is True, it will try to compute local gauges if it can (signal dim is > 2,
             embedding dimension is > 2 or dim embedding is not dim of manifold)
+        start_idx: index of the vertex that seeds the furthest point sampling; set it for
+            reproducibility. The default is None, which draws a random starting vertex.
     """
 
     anchor = [torch.tensor(p).float() for p in utils.to_list(anchor)]
@@ -63,7 +66,10 @@ def construct_dataset(
     for i, (a, v, l, m) in enumerate(zip(anchor, vector, label, mask)):
         for _ in range(number_of_resamples):
             # even sampling of points
-            start_idx = torch.randint(low=0, high=len(a), size=(1,))
-            sample_ind, _ = g.furthest_point_sampling(a, spacing=spacing, start_idx=start_idx)
+            # use a loop-local so a None argument is redrawn for every resample and anchor set
+            fps_start = start_idx
+            if fps_start is None:
+                fps_start = torch.randint(low=0, high=len(a), size=(1,))
+            sample_ind, _ = g.furthest_point_sampling(a, spacing=spacing, start_idx=fps_start)
             sample_ind, _ = torch.sort(sample_ind) #this will make postprocessing easier
             a_, v_, l_, m_ = a[sample_ind], v[sample_ind], l[sample_ind], m[sample_ind]