diff --git a/MARBLE/preprocessing.py b/MARBLE/preprocessing.py index d05844eb..b4bc2cdf 100644 --- a/MARBLE/preprocessing.py +++ b/MARBLE/preprocessing.py @@ -24,7 +24,7 @@ def construct_dataset( local_gauges=False, seed=None, metric="euclidean", - eigendecomposition=True + number_of_eigenvectors=None ): """Construct PyG dataset from node positions and features. @@ -46,7 +46,7 @@ def construct_dataset( seed: Specify for reproducibility in the furthest point sampling. The default is None, which means a random starting vertex. metric: metric used to fit proximity graph - eigendecomposition: perform eigendecomposition (needed for diffusion). + number_of_eigenvectors: integer number of eigenvectors to use. Default: None, meaning use all. """ anchor = [torch.tensor(a).float() for a in utils.to_list(anchor)] @@ -115,7 +115,7 @@ def construct_dataset( local_gauges=local_gauges, n_geodesic_nb=k * frac_geodesic_nb, var_explained=var_explained, - eigendecomposition=eigendecomposition + number_of_eigenvectors=number_of_eigenvectors ) @@ -124,7 +124,7 @@ def _compute_geometric_objects( n_geodesic_nb=10, var_explained=0.9, local_gauges=False, - eigendecomposition=True + number_of_eigenvectors=None ): """ Compute geometric objects used later: local gauges, Levi-Civita connections @@ -135,6 +135,7 @@ def _compute_geometric_objects( n_geodesic_nb: number of geodesic neighbours to fit the tangent spaces to var_explained: fraction of variance explained by the local gauges local_gauges: whether to use local or global gauges + number_of_eigenvectors: integer number of eigenvectors to use. Default: None, meaning use all. Returns: data: pytorch geometric data object with the following new attributes @@ -194,10 +195,9 @@ def _compute_geometric_objects( kernels = g.gradient_op(data.pos, data.edge_index, gauges) Lc = None - if eigendecomposition: - print("\n---- Computing eigendecomposition ... ", end="") - L = g.compute_eigendecomposition(L) - Lc = g.compute_eigendecomposition(Lc) + print("\n---- Computing eigendecomposition ... ", end="") + L = g.compute_eigendecomposition(L, k=number_of_eigenvectors) + Lc = g.compute_eigendecomposition(Lc, k=number_of_eigenvectors) data.kernels = [ utils.to_SparseTensor(K.coalesce().indices(), value=K.coalesce().values()) for K in kernels