diff --git a/DiffusionEMD/dataset.py b/DiffusionEMD/dataset.py index c6fd297..d34dcb4 100644 --- a/DiffusionEMD/dataset.py +++ b/DiffusionEMD/dataset.py @@ -9,6 +9,7 @@ import sklearn.datasets as skd import sklearn.metrics from sklearn.neighbors import kneighbors_graph +from sklearn.neighbors import radius_neighbors_graph import ot import pygsp @@ -72,7 +73,7 @@ def get_graph(self): return self.graph class Line(Dataset): - def __init__(self, n_points, random_state=42): + def __init__(self, n_points, epsilon=0.1, random_state=42): super().__init__() self.n_points = n_points N = n_points @@ -83,11 +84,19 @@ def __init__(self, n_points, random_state=42): # [np.cos(2 * np.pi * self.X[:, 0]), np.sin(2 * np.pi * self.X[:, 0])], # axis=1 # ) - self.graph = pygsp.graphs.NNGraph( - self.X, epsilon=0.1, NNtype="radius", rescale=False, center=False - ) + self.graph = self.create_radius_graph(self.X, epsilon) self.labels = np.eye(N) + def create_radius_graph(self, X, epsilon): + """ + Create a graph where each node is connected to all other nodes within a certain radius. + """ + adjacency_matrix = radius_neighbors_graph(X, radius=epsilon, mode='connectivity', include_self=False) + + # Create the pygsp graph using the adjacency matrix + pygsp_graph = pygsp.graphs.Graph(adjacency_matrix) + return pygsp_graph + def get_graph(self): return self.graph diff --git a/DiffusionEMD/metric_tree.py b/DiffusionEMD/metric_tree.py index fc060db..55153bd 100644 --- a/DiffusionEMD/metric_tree.py +++ b/DiffusionEMD/metric_tree.py @@ -10,7 +10,8 @@ import numpy as np from sklearn.base import BaseEstimator from sklearn.utils.validation import check_X_y, check_is_fitted -from sklearn.neighbors import KDTree, BallTree, DistanceMetric +from sklearn.neighbors import KDTree, BallTree +from sklearn.metrics import DistanceMetric from sklearn.cluster import MiniBatchKMeans from scipy.sparse import coo_matrix