Skip to content

Commit

Permalink
Update preprocessing.py
Browse files Browse the repository at this point in the history
  • Loading branch information
agosztolai authored Aug 19, 2024
1 parent 56040d6 commit d9093c8
Showing 1 changed file with 8 additions and 8 deletions.
16 changes: 8 additions & 8 deletions MARBLE/preprocessing.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ def construct_dataset(
local_gauges=False,
seed=None,
metric="euclidean",
eigendecomposition=True
number_of_eigenvectors=None
):
"""Construct PyG dataset from node positions and features.
Expand All @@ -46,7 +46,7 @@ def construct_dataset(
seed: Specify for reproducibility in the furthest point sampling.
The default is None, which means a random starting vertex.
metric: metric used to fit proximity graph
eigendecomposition: perform eigendecomposition (needed for diffusion).
number_of_eigenvectors: integer number of eigenvectors to use. Default: None, meaning use all.
"""

anchor = [torch.tensor(a).float() for a in utils.to_list(anchor)]
Expand Down Expand Up @@ -115,7 +115,7 @@ def construct_dataset(
local_gauges=local_gauges,
n_geodesic_nb=k * frac_geodesic_nb,
var_explained=var_explained,
eigendecomposition=eigendecomposition
number_of_eigenvectors=number_of_eigenvectors
)


Expand All @@ -124,7 +124,7 @@ def _compute_geometric_objects(
n_geodesic_nb=10,
var_explained=0.9,
local_gauges=False,
eigendecomposition=True
number_of_eigenvectors=None
):
"""
Compute geometric objects used later: local gauges, Levi-Civita connections
Expand All @@ -135,6 +135,7 @@ def _compute_geometric_objects(
n_geodesic_nb: number of geodesic neighbours to fit the tangent spaces to
var_explained: fraction of variance explained by the local gauges
local_gauges: whether to use local or global gauges
number_of_eigenvectors: integer number of eigenvectors to use. Default: None, meaning use all.
Returns:
data: pytorch geometric data object with the following new attributes
Expand Down Expand Up @@ -194,10 +195,9 @@ def _compute_geometric_objects(
kernels = g.gradient_op(data.pos, data.edge_index, gauges)
Lc = None

if eigendecomposition:
print("\n---- Computing eigendecomposition ... ", end="")
L = g.compute_eigendecomposition(L)
Lc = g.compute_eigendecomposition(Lc)
print("\n---- Computing eigendecomposition ... ", end="")
L = g.compute_eigendecomposition(L, k=number_of_eigenvectors)
Lc = g.compute_eigendecomposition(Lc, k=number_of_eigenvectors)

data.kernels = [
utils.to_SparseTensor(K.coalesce().indices(), value=K.coalesce().values()) for K in kernels
Expand Down

0 comments on commit d9093c8

Please sign in to comment.