Skip to content

Commit

Permalink
Test fixed_point vs gradient_descent
Browse files Browse the repository at this point in the history
  • Loading branch information
clbonet committed Oct 19, 2024
1 parent be985d1 commit 9a43369
Showing 1 changed file with 28 additions and 0 deletions.
28 changes: 28 additions & 0 deletions test/test_gaussian.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,6 +154,34 @@ def test_bures_wasserstein_barycenter(nx):
np.testing.assert_allclose(Cbdiag, Cdiag_cf, rtol=1e-2, atol=1e-2)


def test_fixedpoint_vs_gradientdescent_bures_wasserstein_barycenter(nx):
n = 50
k = 10
X = []
y = []
m = []
C = []
for _ in range(k):
X_, y_ = make_data_classif('3gauss', n)
m_ = np.mean(X_, axis=0)[None, :]
C_ = np.cov(X_.T)
X.append(X_)
y.append(y_)
m.append(m_)
C.append(C_)
m = np.array(m)
C = np.array(C)
X = nx.from_numpy(*X)
m = nx.from_numpy(m)
C = nx.from_numpy(C)

mb, Cb = ot.gaussian.bures_wasserstein_barycenter(m, C, method="fixed_point", log=False)
mb2, Cb2 = ot.gaussian.bures_wasserstein_barycenter(m, C, method="gradient_descent", log=False)

np.testing.assert_allclose(mb, mb2, atol=1e-5)
np.testing.assert_allclose(Cb, Cb2, atol=1e-5)


@pytest.mark.parametrize("bias", [True, False])
def test_empirical_bures_wasserstein_barycenter(nx, bias):
n = 50
Expand Down

0 comments on commit 9a43369

Please sign in to comment.