Skip to content

Commit

Permalink
Merge pull request #755 from wreise/tests_kernel_metrics
Browse files Browse the repository at this point in the history
Tests for kernels and distances
  • Loading branch information
VincentRouvreau authored Feb 13, 2023
2 parents a54745f + 20116fb commit f5f9876
Show file tree
Hide file tree
Showing 2 changed files with 71 additions and 19 deletions.
15 changes: 8 additions & 7 deletions src/python/gudhi/representations/kernel_methods.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@ def _persistence_scale_space_kernel(D1, D2, kernel_approx=None, bandwidth=1.):
weight_pss = lambda x: 1 if x[1] >= x[0] else -1
return 0.5 * _persistence_weighted_gaussian_kernel(DD1, DD2, weight=weight_pss, kernel_approx=kernel_approx, bandwidth=bandwidth)


def pairwise_persistence_diagram_kernels(X, Y=None, kernel="sliced_wasserstein", n_jobs=None, **kwargs):
"""
This function computes the kernel matrix between two lists of persistence diagrams given as numpy arrays of shape (nx2).
Expand All @@ -79,7 +80,7 @@ def pairwise_persistence_diagram_kernels(X, Y=None, kernel="sliced_wasserstein",
if kernel == "sliced_wasserstein":
return np.exp(-pairwise_persistence_diagram_distances(X, Y, metric="sliced_wasserstein", num_directions=kwargs["num_directions"], n_jobs=n_jobs) / kwargs["bandwidth"])
elif kernel == "persistence_fisher":
return np.exp(-pairwise_persistence_diagram_distances(X, Y, metric="persistence_fisher", kernel_approx=kwargs["kernel_approx"], bandwidth=kwargs["bandwidth"], n_jobs=n_jobs) / kwargs["bandwidth_fisher"])
return np.exp(-pairwise_persistence_diagram_distances(X, Y, metric="persistence_fisher", kernel_approx=kwargs["kernel_approx"], bandwidth=kwargs["bandwidth_fisher"], n_jobs=n_jobs) / kwargs["bandwidth"])
elif kernel == "persistence_scale_space":
return _pairwise(pairwise_kernels, False, XX, YY, metric=_sklearn_wrapper(_persistence_scale_space_kernel, X, Y, **kwargs), n_jobs=n_jobs)
elif kernel == "persistence_weighted_gaussian":
Expand Down Expand Up @@ -123,7 +124,7 @@ def transform(self, X):
X (list of n x 2 numpy arrays): input persistence diagrams.
Returns:
numpy array of shape (number of diagrams in **diagrams**) x (number of diagrams in X): matrix of pairwise sliced Wasserstein kernel values.
numpy array of shape (number of diagrams in X) x (number of diagrams in **diagrams**): matrix of pairwise sliced Wasserstein kernel values.
"""
return pairwise_persistence_diagram_kernels(X, self.diagrams_, kernel="sliced_wasserstein", bandwidth=self.bandwidth, num_directions=self.num_directions, n_jobs=self.n_jobs)

Expand All @@ -138,7 +139,7 @@ def __call__(self, diag1, diag2):
Returns:
float: sliced Wasserstein kernel value.
"""
return np.exp(-_sliced_wasserstein_distance(diag1, diag2, num_directions=self.num_directions)) / self.bandwidth
return np.exp(-_sliced_wasserstein_distance(diag1, diag2, num_directions=self.num_directions) / self.bandwidth)

class PersistenceWeightedGaussianKernel(BaseEstimator, TransformerMixin):
"""
Expand Down Expand Up @@ -177,7 +178,7 @@ def transform(self, X):
X (list of n x 2 numpy arrays): input persistence diagrams.
Returns:
numpy array of shape (number of diagrams in **diagrams**) x (number of diagrams in X): matrix of pairwise persistence weighted Gaussian kernel values.
numpy array of shape (number of diagrams in X) x (number of diagrams in **diagrams**): matrix of pairwise persistence weighted Gaussian kernel values.
"""
return pairwise_persistence_diagram_kernels(X, self.diagrams_, kernel="persistence_weighted_gaussian", bandwidth=self.bandwidth, weight=self.weight, kernel_approx=self.kernel_approx, n_jobs=self.n_jobs)

Expand Down Expand Up @@ -229,7 +230,7 @@ def transform(self, X):
X (list of n x 2 numpy arrays): input persistence diagrams.
Returns:
numpy array of shape (number of diagrams in **diagrams**) x (number of diagrams in X): matrix of pairwise persistence scale space kernel values.
numpy array of shape (number of diagrams in X) x (number of diagrams in **diagrams**): matrix of pairwise persistence scale space kernel values.
"""
return pairwise_persistence_diagram_kernels(X, self.diagrams_, kernel="persistence_scale_space", bandwidth=self.bandwidth, kernel_approx=self.kernel_approx, n_jobs=self.n_jobs)

Expand Down Expand Up @@ -283,7 +284,7 @@ def transform(self, X):
X (list of n x 2 numpy arrays): input persistence diagrams.
Returns:
numpy array of shape (number of diagrams in **diagrams**) x (number of diagrams in X): matrix of pairwise persistence Fisher kernel values.
numpy array of shape (number of diagrams in X) x (number of diagrams in **diagrams**): matrix of pairwise persistence Fisher kernel values.
"""
return pairwise_persistence_diagram_kernels(X, self.diagrams_, kernel="persistence_fisher", bandwidth=self.bandwidth, bandwidth_fisher=self.bandwidth_fisher, kernel_approx=self.kernel_approx, n_jobs=self.n_jobs)

Expand All @@ -298,5 +299,5 @@ def __call__(self, diag1, diag2):
Returns:
float: persistence Fisher kernel value.
"""
return np.exp(-_persistence_fisher_distance(diag1, diag2, bandwidth=self.bandwidth, kernel_approx=self.kernel_approx)) / self.bandwidth_fisher
return np.exp(-_persistence_fisher_distance(diag1, diag2, bandwidth=self.bandwidth_fisher, kernel_approx=self.kernel_approx) / self.bandwidth)

75 changes: 63 additions & 12 deletions src/python/test/test_representations.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,20 +44,71 @@ def _n_diags(n):
l.append(a)
return l


def test_multiple():
metrics_dict = { # (class, metric_kwargs, tolerance_pytest_approx)
"bottleneck": (BottleneckDistance(epsilon=0.00001),
dict(e=0.00001),
dict(abs=1e-5)),
"wasserstein": (WassersteinDistance(order=2, internal_p=2, n_jobs=4),
dict(order=2, internal_p=2, n_jobs=4),
dict(rel=1e-3)),
"sliced_wasserstein": (SlicedWassersteinDistance(num_directions=100, n_jobs=4),
dict(num_directions=100),
dict(rel=1e-3)),
"persistence_fisher": (PersistenceFisherDistance(bandwidth=1., n_jobs=4),
dict(bandwidth=1., n_jobs=4),
dict(abs=1e-5)),
}


def test_distance_transform_consistency():
l1 = _n_diags(9)
l2 = _n_diags(11)
l1b = l1.copy()
d1 = pairwise_persistence_diagram_distances(l1, e=0.00001, n_jobs=4)
d2 = BottleneckDistance(epsilon=0.00001).fit_transform(l1)
d3 = pairwise_persistence_diagram_distances(l1, l1b, e=0.00001, n_jobs=4)
assert d1 == pytest.approx(d2)
assert d3 == pytest.approx(d2, abs=1e-5) # Because of 0 entries (on the diagonal)
d1 = pairwise_persistence_diagram_distances(l1, l2, metric="wasserstein", order=2, internal_p=2)
d2 = WassersteinDistance(order=2, internal_p=2, n_jobs=4).fit(l2).transform(l1)
print(d1.shape, d2.shape)
assert d1 == pytest.approx(d2, rel=0.02)
for metricName, (metricClass, metricParams, tolerance) in metrics_dict.items():
d1 = pairwise_persistence_diagram_distances(l1, metric=metricName, **metricParams)
d2 = metricClass.fit_transform(l1)
assert d1 == pytest.approx(d2)
d3 = pairwise_persistence_diagram_distances(l1, l1b, metric=metricName, **metricParams)
assert d3 == pytest.approx(d2, **tolerance) # Because of 0 entries (on the diagonal)
d4 = metricClass.fit(l1).transform(l1b)
assert d4 == pytest.approx(d2, **tolerance)


kernel_dict = {
"sliced_wasserstein": (SlicedWassersteinKernel(num_directions=10, bandwidth=4., n_jobs=4),
dict(num_directions=10), dict(rel=1e-3)),
"persistence_fisher": (PersistenceFisherKernel(bandwidth_fisher=3., bandwidth=1.),
dict(bandwidth=3.), # corresponds to bandwidth_fisher in the kernel class
dict(rel=1e-3)),
"persistence_weighted_gaussian": (PersistenceWeightedGaussianKernel(bandwidth=4.,
weight=lambda x: x[1]-x[0]),
dict(bandwidth=4., weight=lambda x: x[1]-x[0]),
dict(rel=1e-3)),
"persistence_scale_space": (PersistenceScaleSpaceKernel(bandwidth=4.),
dict(bandwidth=4.),
dict(rel=1e-3)),
}
def test_kernel_from_distance():
l1, l2 = _n_diags(9), _n_diags(11)
for kernelName in ["sliced_wasserstein", "persistence_fisher"]:
kernelClass, kernelParams, tolerance = kernel_dict[kernelName]
f1 = kernelClass.fit_transform(l1)
d1 = pairwise_persistence_diagram_distances(l1, metric=kernelName, **kernelParams)
assert np.exp(-d1/kernelClass.bandwidth == pytest.approx(f1, **tolerance))

def test_kernel_distance_consistency():
l1, l2 = _n_diags(9), _n_diags(11)
for kernelName, (kernelClass, kernelParams, tolerance) in kernel_dict.items():
_ = kernelClass.fit(l1)
f2 = kernelClass.transform(l2)
f12 = np.array([[kernelClass(l1_, l2_) for l1_ in l1] for l2_ in l2])
assert f12 == pytest.approx(f2, **tolerance)

def test_sliced_wasserstein_distance_value():
diag1 = np.array([[0., 1.], [0., 2.]])
diag2 = np.array([[1., 0.]])
SWD = SlicedWassersteinDistance(num_directions=2)
distance = SWD(diag1, diag2)
assert distance == pytest.approx(2., abs=1e-8)


# Test sorted values as points order can be inverted, and sorted test is not documentation-friendly
Expand Down

0 comments on commit f5f9876

Please sign in to comment.