Skip to content

Commit

Permalink
Merge pull request #8 from farhadmd7/main
Browse files Browse the repository at this point in the history
Update scikit-learn imports
  • Loading branch information
atong01 authored Nov 29, 2024
2 parents 7699f7f + 76c446e commit 26f147c
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 5 deletions.
17 changes: 13 additions & 4 deletions DiffusionEMD/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import sklearn.datasets as skd
import sklearn.metrics
from sklearn.neighbors import kneighbors_graph
from sklearn.neighbors import radius_neighbors_graph
import ot
import pygsp

Expand Down Expand Up @@ -72,7 +73,7 @@ def get_graph(self):
return self.graph

class Line(Dataset):
def __init__(self, n_points, random_state=42):
def __init__(self, n_points, epsilon=0.1, random_state=42):
super().__init__()
self.n_points = n_points
N = n_points
Expand All @@ -83,11 +84,19 @@ def __init__(self, n_points, random_state=42):
# [np.cos(2 * np.pi * self.X[:, 0]), np.sin(2 * np.pi * self.X[:, 0])],
# axis=1
# )
self.graph = pygsp.graphs.NNGraph(
self.X, epsilon=0.1, NNtype="radius", rescale=False, center=False
)
self.graph = self.create_radius_graph(self.X, epsilon)
self.labels = np.eye(N)

def create_radius_graph(self, X, epsilon):
"""
Create a graph where each node is connected to all other nodes within a certain radius.
"""
adjacency_matrix = radius_neighbors_graph(X, radius=epsilon, mode='connectivity', include_self=False)

# Create the pygsp graph using the adjacency matrix
pygsp_graph = pygsp.graphs.Graph(adjacency_matrix)
return pygsp_graph

def get_graph(self):
return self.graph

Expand Down
3 changes: 2 additions & 1 deletion DiffusionEMD/metric_tree.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,8 @@
import numpy as np
from sklearn.base import BaseEstimator
from sklearn.utils.validation import check_X_y, check_is_fitted
from sklearn.neighbors import KDTree, BallTree, DistanceMetric
from sklearn.neighbors import KDTree, BallTree
from sklearn.metrics import DistanceMetric
from sklearn.cluster import MiniBatchKMeans
from scipy.sparse import coo_matrix

Expand Down

0 comments on commit 26f147c

Please sign in to comment.