From aef34dd9536334d429fbe4304887eae6d297a72e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?R=C3=A9mi=20Flamary?= Date: Mon, 7 Aug 2023 14:40:48 +0200 Subject: [PATCH 01/13] gaussian gromov distance --- README.md | 5 + ot/backend.py | 25 +++++ ot/gaussian.py | 244 ++++++++++++++++++++++++++++++++++++++++++ test/test_backend.py | 4 + test/test_gaussian.py | 28 +++++ 5 files changed, 306 insertions(+) diff --git a/README.md b/README.md index a5660a988..fdcfc9551 100644 --- a/README.md +++ b/README.md @@ -324,3 +324,8 @@ Dictionary Learning](https://arxiv.org/pdf/2102.06555.pdf), International Confer [55] Ronak Mehta, Jeffery Kline, Vishnu Suresh Lokhande, Glenn Fung, & Vikas Singh (2023). [Efficient Discrete Multi Marginal Optimal Transport Regularization](https://openreview.net/forum?id=R98ZfMt-jE). In The Eleventh International Conference on Learning Representations (ICLR). [56] Jeffery Kline. [Properties of the d-dimensional earth mover’s problem](https://www.sciencedirect.com/science/article/pii/S0166218X19301441). Discrete Applied Mathematics, 265: 128–141, 2019. + +[57] Delon, J., Desolneux, A., & Salmona, A. (2022). [Gromov–Wasserstein +distances between Gaussian distributions](https://hal.science/hal-03197398v2/file/main.pdf). Journal of Applied Probability, 59(4), +1178-1198. + diff --git a/ot/backend.py b/ot/backend.py index 1b6ca606b..974234831 100644 --- a/ot/backend.py +++ b/ot/backend.py @@ -858,6 +858,16 @@ def sqrtm(self, a): """ raise NotImplementedError() + def eigh(self, a): + r""" + Computes the eigenvalues and eigenvectors of a symmetric tensor. + + This function follows the api from :any:`scipy.linalg.eigh`. + + See: https://docs.scipy.org/doc/scipy/reference/generated/scipy.linalg.eigh.html + """ + raise NotImplementedError() + def kl_div(self, p, q, eps=1e-16): r""" Computes the Kullback-Leibler divergence. @@ -1253,6 +1263,9 @@ def sqrtm(self, a): L, V = np.linalg.eigh(a) return (V * np.sqrt(L)[None, :]) @ V.T + def eigh(self, a): + return np.linalg.eigh(a) + def kl_div(self, p, q, eps=1e-16): return np.sum(p * np.log(p / q + eps)) @@ -1631,6 +1644,9 @@ def sqrtm(self, a): L, V = jnp.linalg.eigh(a) return (V * jnp.sqrt(L)[None, :]) @ V.T + def eigh(self, a): + return jnp.linalg.eigh(a) + def kl_div(self, p, q, eps=1e-16): return jnp.sum(p * jnp.log(p / q + eps)) @@ -2106,6 +2122,9 @@ def sqrtm(self, a): L, V = torch.linalg.eigh(a) return (V * torch.sqrt(L)[None, :]) @ V.T + def eigh(self, a): + return torch.linalg.eigh(a) + def kl_div(self, p, q, eps=1e-16): return torch.sum(p * torch.log(p / q + eps)) @@ -2495,6 +2514,9 @@ def sqrtm(self, a): L, V = cp.linalg.eigh(a) return (V * cp.sqrt(L)[None, :]) @ V.T + def eigh(self, a): + return cp.linalg.eigh(a) + def kl_div(self, p, q, eps=1e-16): return cp.sum(p * cp.log(p / q + eps)) @@ -2902,6 +2924,9 @@ def sqrtm(self, a): L, V = tf.linalg.eigh(a) return (V * tf.sqrt(L)[None, :]) @ V.T + def eigh(self, a): + return tf.linalg.eigh(a) + def kl_div(self, p, q, eps=1e-16): return tnp.sum(p * tnp.log(p / q + eps)) diff --git a/ot/gaussian.py b/ot/gaussian.py index 1a295567d..da399af8e 100644 --- a/ot/gaussian.py +++ b/ot/gaussian.py @@ -331,3 +331,247 @@ def empirical_bures_wasserstein_distance(xs, xt, reg=1e-6, ws=None, else: W = bures_wasserstein_distance(mxs, mxt, Cs, Ct) return W + + +def gaussian_gromov_wasserstein_distance(Cov_s, Cov_t, log=False): + r""" Return the Gaussian Gromov-Wasserstein value from [57]. 
+
+    This function returns the closed form value of the Gaussian Gromov-Wasserstein
+    distance between two Gaussian distributions
+    :math:`\mathcal{N}(\mu_s,\Sigma_s)` and :math:`\mathcal{N}(\mu_t,\Sigma_t)`
+    when the OT plan is assumed to also be Gaussian. See [57] Theorem 4.1 for
+    more details.
+
+    Parameters
+    ----------
+    Cov_s : array-like (d,d)
+        covariance of the source distribution
+    Cov_t : array-like (d,d)
+        covariance of the target distribution
+    log : bool, optional
+        record log if True
+
+
+    Returns
+    -------
+    G : float
+        Gaussian Gromov-Wasserstein distance
+
+
+    .. _references-gaussien_gromov_wasserstein_distance:
+    References
+    ----------
+    [57] Delon, J., Desolneux, A., & Salmona, A. (2022). Gromov–Wasserstein
+    distances between Gaussian distributions. Journal of Applied Probability,
+    59(4),
+    1178-1198.
+    """
+
+    nx = get_backend(Cov_s, Cov_t)
+
+    # ensure that Cov_s is the largest covariance matrix
+    # that is m >= n
+    if Cov_s.shape[0] < Cov_t.shape[0]:
+        Cov_s, Cov_t = Cov_t, Cov_s
+
+    n = Cov_t.shape[0]
+
+    # compute and sort the eigenvalues in decreasing order
+    d_s = nx.flip(nx.sort(nx.eigh(Cov_s)[0]))
+    d_t = nx.flip(nx.sort(nx.eigh(Cov_t)[0]))
+
+    # compute the Gaussian Gromov-Wasserstein distance
+    res = 4 * (nx.sum(d_s) - nx.sum(d_t))**2 + 8 * nx.sum((d_s[:n] - d_t)**2) + 8 * nx.sum((d_s[n:])**2)
+    if log:
+        log = {}
+        log['d_s'] = d_s
+        log['d_t'] = d_t
+        return nx.sqrt(res), log
+    else:
+        return nx.sqrt(res)
+
+
+def empirical_gaussian_gromov_wasserstein_distance(xs, xt, ws=None,
+                                                   wt=None, log=False):
+    r"""Return Gaussian Gromov-Wasserstein distance between samples.
+
+    The function estimates the Gaussian Gromov-Wasserstein distance between two
+    Gaussian distributions, source :math:`\mu_s` and target :math:`\mu_t`, whose
+    parameters are estimated from the provided samples :math:`\mathcal{X}_s` and
+    :math:`\mathcal{X}_t`. See [57] Theorem 4.1 for more details.
+
+    Parameters
+    ----------
+    xs : array-like (ns,d)
+        samples in the source domain
+    xt : array-like (nt,d)
+        samples in the target domain
+    ws : array-like (ns,1), optional
+        weights for the source samples
+    wt : array-like (ns,1), optional
+        weights for the target samples
+    log : bool, optional
+        record log if True
+
+
+    Returns
+    -------
+    G : float
+        Gaussian Gromov-Wasserstein distance
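+
+    Examples
+    --------
+    A minimal usage sketch (the source and target spaces may have different
+    dimensions, as in this example):
+
+    >>> import numpy as np
+    >>> import ot
+    >>> rng = np.random.RandomState(0)
+    >>> xs = rng.randn(100, 3)  # 3d source samples
+    >>> xt = rng.randn(100, 2)  # 2d target samples
+    >>> G = ot.gaussian.empirical_gaussian_gromov_wasserstein_distance(xs, xt)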
+
+
+    .. _references-gaussien_gromov_wasserstein:
+    References
+    ----------
+    [57] Delon, J., Desolneux, A., & Salmona, A. (2022). Gromov–Wasserstein
+    distances between Gaussian distributions. Journal of Applied Probability,
+    59(4), 1178-1198.
+    """
+    xs, xt = list_to_array(xs, xt)
+    nx = get_backend(xs, xt)
+
+    ds = xs.shape[1]
+    dt = xt.shape[1]
+
+    if ws is None:
+        ws = nx.ones((xs.shape[0]), type_as=xs) / xs.shape[0]
+
+    if wt is None:
+        wt = nx.ones((xt.shape[0]), type_as=xt) / xt.shape[0]
+
+    mxs = nx.dot(ws, xs) / nx.sum(ws)
+    mxt = nx.dot(wt, xt) / nx.sum(wt)
+
+    xs = xs - mxs
+    xt = xt - mxt
+
+    Cs = nx.dot((xs * ws[:, None]).T, xs) / nx.sum(ws)
+    Ct = nx.dot((xt * wt[:, None]).T, xt) / nx.sum(wt)
+
+    if log:
+        G, log = gaussian_gromov_wasserstein_distance(Cs, Ct, log=log)
+        log['Cov_s'] = Cs
+        log['Cov_t'] = Ct
+        return G, log
+    else:
+        G = gaussian_gromov_wasserstein_distance(Cs, Ct)
+        return G
+
+
+def gaussian_gromov_wasserstein_mapping(mu_s, mu_t, Cov_s, Cov_t, sign_eigs=None, log=False):
+    r""" Return the Gaussian Gromov-Wasserstein mapping from [57].
+
+    This function returns the closed form value of the Gaussian
+    Gromov-Wasserstein mapping between two Gaussian distributions
+    :math:`\mathcal{N}(\mu_s,\Sigma_s)` and :math:`\mathcal{N}(\mu_t,\Sigma_t)`
+    when the OT plan is assumed to also be Gaussian. See [57] Theorem 4.1 for
+    more details.
+
+    Parameters
+    ----------
+    mu_s : array-like (ds,)
+        mean of the source distribution
+    mu_t : array-like (dt,)
+        mean of the target distribution
+    Cov_s : array-like (ds,ds)
+        covariance of the source distribution
+    Cov_t : array-like (dt,dt)
+        covariance of the target distribution
+    sign_eigs : array-like (min(ds,dt),), optional
+        sign of the eigenvalues of the mapping matrix
+    log : bool, optional
+        record log if True
+
+
+    Returns
+    -------
+    A : (ds, dt) array-like
+        Linear operator
+    b : (1, dt) array-like
+        bias
+
+
+    .. _references-gaussien_gromov_wasserstein_mapping:
+    References
+    ----------
+    [57] Delon, J., Desolneux, A., & Salmona, A. (2022). Gromov–Wasserstein
+    distances between Gaussian distributions. Journal of Applied Probability,
+    59(4), 1178-1198.
+    """
+
+    nx = get_backend(mu_s, mu_t, Cov_s, Cov_t)
+
+    n = Cov_t.shape[0]
+    m = Cov_s.shape[0]
+
+    # compte and sort eigenvalues/eigenvectors decreasingly
+    d_s, U_s = nx.eigh(Cov_s)
+    id_s = nx.flip(nx.argsort(d_s))
+    ds, Us = d_s[id_s], U_s[:, id_s]
+
+    d_t, U_t = nx.eigh(Cov_t)
+    id_t = nx.flip(nx.argsort(d_t))
+    dt, Ut = d_t[id_t], U_t[:, id_t]
+
+    if sign_eigs is None:
+        sign_eigs = nx.ones(min(m, n), type_as=mu_s)
+
+    if m >= n:
+        A = nx.concatenate((nx.diag(sign_eigs * d_t / d_s[:n]), nx.zeros((n, m - n), type_as=mu_s)), axis=1).T
+    else:
+        A = nx.concatenate((nx.diag(sign_eigs * d_t[:m] / d_s), nx.zeros((n - m, m), type_as=mu_s)), axis=0).T
+
+    A = nx.dot(nx.dot(U_s, A), Ut.T)
+
+    # compute the gaussien Gromov-Wasserstein dis
+    b = mu_t - nx.dot(mu_s, A)
+
+    if log:
+        log = {}
+        log['d_s'] = d_s
+        log['d_t'] = d_t
+        log['U_s'] = U_s
+        log['U_t'] = U_t
+        return A, b, log
+    else:
+        return A, b
+
+
+def empirical_gaussian_gromov_wasserstein_mapping(xs, xt, ws=None,
+                                                  wt=None, sign_eigs=None, log=False):
+    r"""Return Gaussian Gromov-Wasserstein mapping between samples.
+
+    The function estimates the Gaussian Gromov-Wasserstein mapping between two
+    Gaussian distributions, source :math:`\mu_s` and target :math:`\mu_t`, whose
+    parameters are estimated from the provided samples :math:`\mathcal{X}_s` and
+    :math:`\mathcal{X}_t`. See [57] Theorem 4.1 for more details.
+
+
+    Parameters
+    ----------
+    xs : array-like (ns,ds)
+        samples in the source domain
+    xt : array-like (nt,dt)
+        samples in the target domain
+    ws : array-like (ns,1), optional
+        weights for the source samples
+    wt : array-like (ns,1), optional
+        weights for the target samples
+    sign_eigs : array-like (min(ds,dt),), optional
+        sign of the eigenvalues of the mapping matrix
+    log : bool, optional
+        record log if True
+
+
+    Returns
+    -------
+    A : (ds, dt) array-like
+        Linear operator
+    b : (1, dt) array-like
+        bias
+
+    .. _references-empirical_gaussian_gromov_wasserstein_mapping:
+    References
+    ----------
+    [57] Delon, J., Desolneux, A., & Salmona, A. (2022). Gromov–Wasserstein
+    distances between Gaussian distributions. Journal of Applied Probability,
+    59(4), 1178-1198.
+ """ + + pass diff --git a/test/test_backend.py b/test/test_backend.py index 8f7cd9ec1..b161746bf 100644 --- a/test/test_backend.py +++ b/test/test_backend.py @@ -588,6 +588,10 @@ def test_func_backends(nx): lst_b.append(nx.to_numpy(A)) lst_name.append("matrix square root") + D, U = nx.eigh(SquareMb.T @ SquareMb) + lst_b.append(nx.to_numpy(nx.dot(U, nx.dot(nx.diag(D), U.T)))) + lst_name.append("eigh ") + A = nx.kl_div(nx.abs(Mb), nx.abs(Mb) + 1) lst_b.append(nx.to_numpy(A)) lst_name.append("Kullback-Leibler divergence") diff --git a/test/test_gaussian.py b/test/test_gaussian.py index be7a80651..7e6db1948 100644 --- a/test/test_gaussian.py +++ b/test/test_gaussian.py @@ -96,3 +96,31 @@ def test_empirical_bures_wasserstein_distance(nx, bias): np.testing.assert_allclose(nx.to_numpy(Wb_log), nx.to_numpy(Wb), rtol=1e-2, atol=1e-2) np.testing.assert_allclose(10 * bias, nx.to_numpy(Wb), rtol=1e-2, atol=1e-2) + + +@pytest.mark.parametrize("d_target", [1, 2, 3, 10]) +def test_gaussian_gromov_wasserstein_distance(nx, d_target): + ns = 400 + nt = 400 + + rng = np.random.RandomState(10) + Xs, ys = make_data_classif('3gauss', ns, random_state=rng) + Xt, yt = make_data_classif('3gauss2', nt, random_state=rng) + Xt = np.concatenate((Xt, rng.normal(0, 1, (nt, 8))), axis=1) + Xt = Xt[:, 0:d_target].reshape((nt, d_target)) + + ms = np.mean(Xs, axis=0)[None, :] + mt = np.mean(Xt, axis=0)[None, :] + Cs = np.cov(Xs.T) + Ct = np.cov(Xt.T).reshape((d_target, d_target)) + + Xsb, Xtb, msb, mtb, Csb, Ctb = nx.from_numpy(Xs, Xt, ms, mt, Cs, Ct) + + Gb, log = ot.gaussian.gaussian_gromov_wasserstein_distance(Csb, Ctb, log=True) + Ge, log = ot.gaussian.empirical_gaussian_gromov_wasserstein_distance(Xsb, Xtb, log=True) + + # no log + Ge0 = ot.gaussian.empirical_gaussian_gromov_wasserstein_distance(Xsb, Xtb, log=False) + + np.testing.assert_allclose(nx.to_numpy(Gb), nx.to_numpy(Ge), rtol=1e-2, atol=1e-2) + np.testing.assert_allclose(nx.to_numpy(Ge), nx.to_numpy(Ge0), rtol=1e-2, atol=1e-2) From a65979b4afdd63464f07b62f78616707aee0f3a0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?R=C3=A9mi=20Flamary?= Date: Mon, 7 Aug 2023 14:55:05 +0200 Subject: [PATCH 02/13] gaussian gromov distance --- ot/gaussian.py | 33 ++++++++++++++++++++++++++++++++- test/test_gaussian.py | 31 +++++++++++++++++++++++++++++++ 2 files changed, 63 insertions(+), 1 deletion(-) diff --git a/ot/gaussian.py b/ot/gaussian.py index da399af8e..3c693b2e9 100644 --- a/ot/gaussian.py +++ b/ot/gaussian.py @@ -574,4 +574,35 @@ def empirical_gaussian_gromov_wasserstein_mapping(xs, xt, ws=None, 59(4), 1178-1198. 
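+
+    Examples
+    --------
+    A minimal usage sketch of the estimator (the two spaces can have
+    different dimensions, here 3d source samples mapped to a 2d target
+    space):
+
+    >>> import numpy as np
+    >>> import ot
+    >>> rng = np.random.RandomState(0)
+    >>> xs = rng.randn(50, 3)  # 3d source samples
+    >>> xt = rng.randn(40, 2)  # 2d target samples
+    >>> A, b = ot.gaussian.empirical_gaussian_gromov_wasserstein_mapping(xs, xt)
+    >>> xst = xs.dot(A) + b  # apply the affine mapping to the source samples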
""" - pass + xs, xt = list_to_array(xs, xt) + nx = get_backend(xs, xt) + + ds = xs.shape[1] + dt = xt.shape[1] + + if ws is None: + ws = nx.ones((xs.shape[0]), type_as=xs) / xs.shape[0] + + if wt is None: + wt = nx.ones((xt.shape[0]), type_as=xt) / xt.shape[0] + + mxs = nx.dot(ws, xs) / nx.sum(ws) + mxt = nx.dot(wt, xt) / nx.sum(wt) + + xs = xs - mxs + xt = xt - mxt + + Cs = nx.dot((xs * ws[:, None]).T, xs) / nx.sum(ws) + Ct = nx.dot((xt * wt[:, None]).T, xt) / nx.sum(wt) + + if log: + + A, b, log = gaussian_gromov_wasserstein_mapping(mxs, mxt, Cs, Ct, sign_eigs=sign_eigs, log=log) + log['Cov_s'] = Cs + log['Cov_t'] = Ct + return A, b, log + + else: + A, b = gaussian_gromov_wasserstein_mapping(mxs, mxt, Cs, Ct, sign_eigs=sign_eigs) + return A, b + diff --git a/test/test_gaussian.py b/test/test_gaussian.py index 7e6db1948..130a126d2 100644 --- a/test/test_gaussian.py +++ b/test/test_gaussian.py @@ -124,3 +124,34 @@ def test_gaussian_gromov_wasserstein_distance(nx, d_target): np.testing.assert_allclose(nx.to_numpy(Gb), nx.to_numpy(Ge), rtol=1e-2, atol=1e-2) np.testing.assert_allclose(nx.to_numpy(Ge), nx.to_numpy(Ge0), rtol=1e-2, atol=1e-2) + +@pytest.mark.parametrize("d_target", [1, 2, 3, 10]) +def test_gaussian_gromov_wasserstein_mapping(nx, d_target): + ns = 400 + nt = 400 + + rng = np.random.RandomState(10) + Xs, ys = make_data_classif('3gauss', ns, random_state=rng) + Xt, yt = make_data_classif('3gauss2', nt, random_state=rng) + Xt = np.concatenate((Xt, rng.normal(0, 1, (nt, 8))), axis=1) + Xt = Xt[:, 0:d_target].reshape((nt, d_target)) + + ms = np.mean(Xs, axis=0)[None, :] + mt = np.mean(Xt, axis=0)[None, :] + Cs = np.cov(Xs.T) + Ct = np.cov(Xt.T).reshape((d_target, d_target)) + + Xsb, Xtb, msb, mtb, Csb, Ctb = nx.from_numpy(Xs, Xt, ms, mt, Cs, Ct) + + A,b , log = ot.gaussian.gaussian_gromov_wasserstein_mapping(msb, mtb, Csb, Ctb, log=True) + Ae, be, loge = ot.gaussian.empirical_gaussian_gromov_wasserstein_mapping(Xsb, Xtb, log=True) + + # no log + Ae0, be0 = ot.gaussian.empirical_gaussian_gromov_wasserstein_mapping(Xsb, Xtb, log=False) + + Xst = nx.to_numpy(nx.dot(Xsb, A) + b) + Cst = np.cov(Xst.T) + + np.testing.assert_allclose(nx.to_numpy(A), nx.to_numpy(Ae)) + np.testing.assert_allclose(nx.to_numpy(A), nx.to_numpy(Ae0)) + np.testing.assert_allclose(Ct, Cst) \ No newline at end of file From 589f951a95ed06b565bfc3acd03db99925cc3e75 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?R=C3=A9mi=20Flamary?= Date: Tue, 8 Aug 2023 10:26:33 +0200 Subject: [PATCH 03/13] debug mapping function --- ot/gaussian.py | 12 +++++++----- test/test_gaussian.py | 3 ++- 2 files changed, 9 insertions(+), 6 deletions(-) diff --git a/ot/gaussian.py b/ot/gaussian.py index 3c693b2e9..22a2a1b09 100644 --- a/ot/gaussian.py +++ b/ot/gaussian.py @@ -502,22 +502,24 @@ def gaussian_gromov_wasserstein_mapping(mu_s, mu_t, Cov_s, Cov_t, sign_eigs=None # compte and sort eigenvalues/eigenvectors decreasingly d_s, U_s = nx.eigh(Cov_s) + print(d_s) id_s = nx.flip(nx.argsort(d_s)) - ds, Us = d_s[id_s], U_s[:, id_s] + d_s, U_s = d_s[id_s], U_s[:, id_s] + print(d_s) d_t, U_t = nx.eigh(Cov_t) id_t = nx.flip(nx.argsort(d_t)) - dt, Ut = d_t[id_t], U_t[:, id_t] + d_t, U_t = d_t[id_t], U_t[:, id_t] if sign_eigs is None: sign_eigs = nx.ones(min(m, n), type_as=mu_s) if m >= n: - A = nx.concatenate((nx.diag(sign_eigs * d_t / d_s[:n]), nx.zeros((n, m - n), type_as=mu_s)), axis=1).T + A = nx.concatenate((nx.diag(sign_eigs * nx.sqrt(d_t) / nx.sqrt(d_s[:n])), nx.zeros((n, m - n), type_as=mu_s)), axis=1).T else: - A = 
nx.concatenate((nx.diag(sign_eigs * d_t[:m] / d_s), nx.zeros((n - m, m), type_as=mu_s)), axis=0).T + A = nx.concatenate((nx.diag(sign_eigs *nx.sqrt(d_t[:m]) / nx.sqrt(d_s)), nx.zeros((n - m, m), type_as=mu_s)), axis=0).T - A = nx.dot(nx.dot(U_s, A), Ut.T) + A = nx.dot(nx.dot(U_s, A), U_t.T) # compute the gaussien Gromov-Wasserstein dis b = mu_t - nx.dot(mu_s, A) diff --git a/test/test_gaussian.py b/test/test_gaussian.py index 130a126d2..2c5af6fa2 100644 --- a/test/test_gaussian.py +++ b/test/test_gaussian.py @@ -154,4 +154,5 @@ def test_gaussian_gromov_wasserstein_mapping(nx, d_target): np.testing.assert_allclose(nx.to_numpy(A), nx.to_numpy(Ae)) np.testing.assert_allclose(nx.to_numpy(A), nx.to_numpy(Ae0)) - np.testing.assert_allclose(Ct, Cst) \ No newline at end of file + if d_target <=2: + np.testing.assert_allclose(Ct, Cst) \ No newline at end of file From f5c2c6b608bf0bd0e8af991848e16dfe761ce886 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?R=C3=A9mi=20Flamary?= Date: Tue, 8 Aug 2023 10:27:17 +0200 Subject: [PATCH 04/13] cleanup debug and pep8 --- ot/gaussian.py | 5 +---- test/test_gaussian.py | 7 ++++--- 2 files changed, 5 insertions(+), 7 deletions(-) diff --git a/ot/gaussian.py b/ot/gaussian.py index 22a2a1b09..de672e30d 100644 --- a/ot/gaussian.py +++ b/ot/gaussian.py @@ -502,10 +502,8 @@ def gaussian_gromov_wasserstein_mapping(mu_s, mu_t, Cov_s, Cov_t, sign_eigs=None # compte and sort eigenvalues/eigenvectors decreasingly d_s, U_s = nx.eigh(Cov_s) - print(d_s) id_s = nx.flip(nx.argsort(d_s)) d_s, U_s = d_s[id_s], U_s[:, id_s] - print(d_s) d_t, U_t = nx.eigh(Cov_t) id_t = nx.flip(nx.argsort(d_t)) @@ -517,7 +515,7 @@ def gaussian_gromov_wasserstein_mapping(mu_s, mu_t, Cov_s, Cov_t, sign_eigs=None if m >= n: A = nx.concatenate((nx.diag(sign_eigs * nx.sqrt(d_t) / nx.sqrt(d_s[:n])), nx.zeros((n, m - n), type_as=mu_s)), axis=1).T else: - A = nx.concatenate((nx.diag(sign_eigs *nx.sqrt(d_t[:m]) / nx.sqrt(d_s)), nx.zeros((n - m, m), type_as=mu_s)), axis=0).T + A = nx.concatenate((nx.diag(sign_eigs * nx.sqrt(d_t[:m]) / nx.sqrt(d_s)), nx.zeros((n - m, m), type_as=mu_s)), axis=0).T A = nx.dot(nx.dot(U_s, A), U_t.T) @@ -607,4 +605,3 @@ def empirical_gaussian_gromov_wasserstein_mapping(xs, xt, ws=None, else: A, b = gaussian_gromov_wasserstein_mapping(mxs, mxt, Cs, Ct, sign_eigs=sign_eigs) return A, b - diff --git a/test/test_gaussian.py b/test/test_gaussian.py index 2c5af6fa2..a308503ca 100644 --- a/test/test_gaussian.py +++ b/test/test_gaussian.py @@ -125,6 +125,7 @@ def test_gaussian_gromov_wasserstein_distance(nx, d_target): np.testing.assert_allclose(nx.to_numpy(Gb), nx.to_numpy(Ge), rtol=1e-2, atol=1e-2) np.testing.assert_allclose(nx.to_numpy(Ge), nx.to_numpy(Ge0), rtol=1e-2, atol=1e-2) + @pytest.mark.parametrize("d_target", [1, 2, 3, 10]) def test_gaussian_gromov_wasserstein_mapping(nx, d_target): ns = 400 @@ -143,7 +144,7 @@ def test_gaussian_gromov_wasserstein_mapping(nx, d_target): Xsb, Xtb, msb, mtb, Csb, Ctb = nx.from_numpy(Xs, Xt, ms, mt, Cs, Ct) - A,b , log = ot.gaussian.gaussian_gromov_wasserstein_mapping(msb, mtb, Csb, Ctb, log=True) + A, b, log = ot.gaussian.gaussian_gromov_wasserstein_mapping(msb, mtb, Csb, Ctb, log=True) Ae, be, loge = ot.gaussian.empirical_gaussian_gromov_wasserstein_mapping(Xsb, Xtb, log=True) # no log @@ -154,5 +155,5 @@ def test_gaussian_gromov_wasserstein_mapping(nx, d_target): np.testing.assert_allclose(nx.to_numpy(A), nx.to_numpy(Ae)) np.testing.assert_allclose(nx.to_numpy(A), nx.to_numpy(Ae0)) - if d_target <=2: - np.testing.assert_allclose(Ct, Cst) \ No 
newline at end of file
+    if d_target <= 2:
+        np.testing.assert_allclose(Ct, Cst)

From 5ef1fed0050167cb0fed09823d835d8d83b7dee9 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?R=C3=A9mi=20Flamary?=
Date: Tue, 8 Aug 2023 10:30:59 +0200
Subject: [PATCH 05/13] add test for different source sizes

---
 test/test_gaussian.py | 9 +++++++++
 1 file changed, 9 insertions(+)

diff --git a/test/test_gaussian.py b/test/test_gaussian.py
index a308503ca..8e5ca048b 100644
--- a/test/test_gaussian.py
+++ b/test/test_gaussian.py
@@ -157,3 +157,12 @@ def test_gaussian_gromov_wasserstein_mapping(nx, d_target):
     np.testing.assert_allclose(nx.to_numpy(A), nx.to_numpy(Ae0))
     if d_target <= 2:
         np.testing.assert_allclose(Ct, Cst)
+
+    # test the other way around (target to source)
+    Ai, bi, logi = ot.gaussian.gaussian_gromov_wasserstein_mapping(mtb, msb, Ctb, Csb, log=True)
+
+    Xtt = nx.to_numpy(nx.dot(Xtb, Ai) + bi)
+    Ctt = np.cov(Xtt.T)
+
+    if d_target >= 2:
+        np.testing.assert_allclose(Cs, Ctt)

From 04e95908fb42b532adb158b37f2e7ad179c7d5f4 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?R=C3=A9mi=20Flamary?=
Date: Tue, 8 Aug 2023 10:55:36 +0200
Subject: [PATCH 06/13] add transport classes with GW

---
 .../plot_otda_linear_mapping.py |  4 +-
 ot/da.py                        | 97 ++++++++++++++++++-
 ot/gaussian.py                  | 25 ++---
 test/test_da.py                 | 27 ++++++
 4 files changed, 135 insertions(+), 18 deletions(-)

diff --git a/examples/domain-adaptation/plot_otda_linear_mapping.py b/examples/domain-adaptation/plot_otda_linear_mapping.py
index 8284a2a93..03ebd2cca 100644
--- a/examples/domain-adaptation/plot_otda_linear_mapping.py
+++ b/examples/domain-adaptation/plot_otda_linear_mapping.py
@@ -13,6 +13,8 @@
 # License: MIT License

 # sphinx_gallery_thumbnail_number = 2
+
+#%%
 import os
 from pathlib import Path

@@ -75,7 +77,7 @@
 plt.plot(xs[:, 0], xs[:, 1], '+')
 plt.plot(xt[:, 0], xt[:, 1], 'o')
 plt.plot(xst[:, 0], xst[:, 1], '+')
-
+plt.legend(('Source', 'Target', 'Transp'), loc=0)
 plt.show()

 ##############################################################################
diff --git a/ot/da.py b/ot/da.py
index 886b7ee52..c8ab966cf 100644
--- a/ot/da.py
+++ b/ot/da.py
@@ -19,7 +19,7 @@
 from .utils import unif, dist, kernel, cost_normalization, label_normalization, laplacian, dots
 from .utils import list_to_array, check_params, BaseEstimator, deprecated
 from .unbalanced import sinkhorn_unbalanced
-from .gaussian import empirical_bures_wasserstein_mapping
+from .gaussian import empirical_bures_wasserstein_mapping, empirical_gaussian_gromov_wasserstein_mapping
 from .optim import cg
 from .optim import gcg

@@ -1360,6 +1360,101 @@ class label
         return transp_Xt

+class LinearGWTransport(LinearTransport):
+    r""" OT Gaussian Gromov-Wasserstein linear operator between empirical distributions
+
+    The function estimates the optimal linear operator that aligns the two
+    empirical distributions optimaly wrt the Gromov wassretsein distance. This is equivalent to estimating the closed
+    form mapping between two Gaussian distributions :math:`\mathcal{N}(\mu_s,\Sigma_s)`
+    and :math:`\mathcal{N}(\mu_t,\Sigma_t)` as proposed in
+    :ref:`[57] `.
+
+    The linear operator from source to target :math:`M`
+
+    .. math::
+        M(\mathbf{x})= \mathbf{A} \mathbf{x} + \mathbf{b}
+
+    where the matrix :math:`\mathbf{A}` and the vector :math:`\mathbf{b}` are
+    defined in :ref:`[57] `.
+
+
+
+    Parameters
+    ----------
+    log : bool, optional
+        record log if True
+
+
+    .. _references-lineargwtransport:
+    References
+    ----------
+    .. [57] Delon, J., Desolneux, A., & Salmona, A. (2022).
Gromov–Wasserstein + distances between Gaussian distributions. Journal of Applied Probability, + 59(4), 1178-1198. + + """ + + def __init__(self, log=False, + distribution_estimation=distribution_estimation_uniform): + self.log = log + self.distribution_estimation = distribution_estimation + + def fit(self, Xs=None, ys=None, Xt=None, yt=None): + r"""Build a coupling matrix from source and target sets of samples + :math:`(\mathbf{X_s}, \mathbf{y_s})` and :math:`(\mathbf{X_t}, \mathbf{y_t})` + + Parameters + ---------- + Xs : array-like, shape (n_source_samples, n_features) + The training input samples. + ys : array-like, shape (n_source_samples,) + The class labels + Xt : array-like, shape (n_target_samples, n_features) + The training input samples. + yt : array-like, shape (n_target_samples,) + The class labels. If some target samples are unlabelled, fill the + :math:`\mathbf{y_t}`'s elements with -1. + + Warning: Note that, due to this convention -1 cannot be used as a + class label + + Returns + ------- + self : object + Returns self. + """ + nx = self._get_backend(Xs, ys, Xt, yt) + + self.mu_s = self.distribution_estimation(Xs) + self.mu_t = self.distribution_estimation(Xt) + + # coupling estimation + returned_ = empirical_gaussian_gromov_wasserstein_mapping(Xs, Xt, + ws=self.mu_s, + wt=self.mu_t, + log=self.log) + + # deal with the value of log + if self.log: + self.A_, self.B_, self.log_ = returned_ + else: + self.A_, self.B_, = returned_ + self.log_ = dict() + + # re compute inverse mapping + returned_1_ = empirical_gaussian_gromov_wasserstein_mapping(Xt, Xs, + ws=self.mu_t, + wt=self.mu_s, + log=self.log) + if self.log: + self.A1_, self.B1_, self.log_1_ = returned_1_ + else: + self.A1_, self.B1_, = returned_1_ + self.log_ = dict() + + return self + + class SinkhornTransport(BaseTransport): """Domain Adaptation OT method based on Sinkhorn Algorithm diff --git a/ot/gaussian.py b/ot/gaussian.py index de672e30d..8a3c1dcbf 100644 --- a/ot/gaussian.py +++ b/ot/gaussian.py @@ -274,9 +274,9 @@ def empirical_bures_wasserstein_distance(xs, xt, reg=1e-6, ws=None, samples in the target domain reg : float,optional regularization added to the diagonals of covariances (>0) - ws : array-like (ns,1), optional + ws : array-like (ns), optional weights for the source samples - wt : array-like (ns,1), optional + wt : array-like (ns), optional weights for the target samples bias: boolean, optional estimate bias :math:`\mathbf{b}` else :math:`\mathbf{b} = 0` (default:True) @@ -359,10 +359,9 @@ def gaussian_gromov_wasserstein_distance(Cov_s, Cov_t, log=False): .. _references-gaussien_gromov_wasserstein_distance: References ---------- - [57] Delon, J., Desolneux, A., & Salmona, A. (2022). Gromov–Wasserstein + .. [57] Delon, J., Desolneux, A., & Salmona, A. (2022). Gromov–Wasserstein distances between Gaussian distributions. Journal of Applied Probability, - 59(4), - 1178-1198. + 59(4), 1178-1198. """ nx = get_backend(Cov_s, Cov_t) @@ -404,9 +403,9 @@ def empirical_gaussian_gromov_wasserstein_distance(xs, xt, ws=None, samples in the source domain xt : array-like (nt,d) samples in the target domain - ws : array-like (ns,1), optional + ws : array-like (ns), optional weights for the source samples - wt : array-like (ns,1), optional + wt : array-like (ns), optional weights for the target samples log : bool, optional record log if True @@ -421,16 +420,13 @@ def empirical_gaussian_gromov_wasserstein_distance(xs, xt, ws=None, .. 
_references-gaussien_gromov_wasserstein: References ---------- - [57] Delon, J., Desolneux, A., & Salmona, A. (2022). Gromov–Wasserstein + .. [57] Delon, J., Desolneux, A., & Salmona, A. (2022). Gromov–Wasserstein distances between Gaussian distributions. Journal of Applied Probability, 59(4), 1178-1198. """ xs, xt = list_to_array(xs, xt) nx = get_backend(xs, xt) - ds = xs.shape[1] - dt = xt.shape[1] - if ws is None: ws = nx.ones((xs.shape[0]), type_as=xs) / xs.shape[0] @@ -490,7 +486,7 @@ def gaussian_gromov_wasserstein_mapping(mu_s, mu_t, Cov_s, Cov_t, sign_eigs=None .. _references-gaussien_gromov_wasserstein_mapping: References ---------- - [57] Delon, J., Desolneux, A., & Salmona, A. (2022). Gromov–Wasserstein + .. [57] Delon, J., Desolneux, A., & Salmona, A. (2022). Gromov–Wasserstein distances between Gaussian distributions. Journal of Applied Probability, 59(4), 1178-1198. """ @@ -569,7 +565,7 @@ def empirical_gaussian_gromov_wasserstein_mapping(xs, xt, ws=None, .. _references-empirical_gaussian_gromov_wasserstein_mapping: References ---------- - [57] Delon, J., Desolneux, A., & Salmona, A. (2022). Gromov–Wasserstein + .. [57] Delon, J., Desolneux, A., & Salmona, A. (2022). Gromov–Wasserstein distances between Gaussian distributions. Journal of Applied Probability, 59(4), 1178-1198. """ @@ -577,9 +573,6 @@ def empirical_gaussian_gromov_wasserstein_mapping(xs, xt, ws=None, xs, xt = list_to_array(xs, xt) nx = get_backend(xs, xt) - ds = xs.shape[1] - dt = xt.shape[1] - if ws is None: ws = nx.ones((xs.shape[0]), type_as=xs) / xs.shape[0] diff --git a/test/test_da.py b/test/test_da.py index c5f08d6b6..c95d48850 100644 --- a/test/test_da.py +++ b/test/test_da.py @@ -602,6 +602,33 @@ def test_linear_mapping_class(nx): np.testing.assert_allclose(Ct, Cst, rtol=1e-2, atol=1e-2) +@pytest.skip_backend("jax") +@pytest.skip_backend("tf") +def test_linear_gw_mapping_class(nx): + ns = 50 + nt = 50 + + Xs, ys = make_data_classif('3gauss', ns) + Xt, yt = make_data_classif('3gauss2', nt) + + Xsb, Xtb = nx.from_numpy(Xs, Xt) + + otmap = ot.da.LinearGWTransport() + + otmap.fit(Xs=Xsb, Xt=Xtb) + assert hasattr(otmap, "A_") + assert hasattr(otmap, "B_") + assert hasattr(otmap, "A1_") + assert hasattr(otmap, "B1_") + + Xst = nx.to_numpy(otmap.transform(Xs=Xsb)) + + Ct = np.cov(Xt.T) + Cst = np.cov(Xst.T) + + np.testing.assert_allclose(Ct, Cst, rtol=1e-2, atol=1e-2) + + @pytest.skip_backend("jax") @pytest.skip_backend("tf") def test_jcpot_transport_class(nx): From 9db47926bfdae72ea5a9b4df7f52c410626658e2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?R=C3=A9mi=20Flamary?= Date: Tue, 8 Aug 2023 11:09:02 +0200 Subject: [PATCH 07/13] upate linera mapping exmaple --- .../plot_otda_linear_mapping.py | 60 +++++++++++++++---- 1 file changed, 49 insertions(+), 11 deletions(-) diff --git a/examples/domain-adaptation/plot_otda_linear_mapping.py b/examples/domain-adaptation/plot_otda_linear_mapping.py index 03ebd2cca..7e7177d37 100644 --- a/examples/domain-adaptation/plot_otda_linear_mapping.py +++ b/examples/domain-adaptation/plot_otda_linear_mapping.py @@ -57,27 +57,43 @@ plt.figure(1, (5, 5)) plt.plot(xs[:, 0], xs[:, 1], '+') plt.plot(xt[:, 0], xt[:, 1], 'o') - +plt.legend(('Source', 'Target')) +plt.title('Source and target distributions') +plt.show() ############################################################################## # Estimate linear mapping and transport # ------------------------------------- + +# Gaussian (linear) Monge mapping estimation Ae, be = ot.gaussian.empirical_bures_wasserstein_mapping(xs, xt) 
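+# the estimated Monge mapping is affine: x -> x.dot(Ae) + be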
xst = xs.dot(Ae) + be +# Gaussian (linear) GW mapping estimation +Agw, bgw = ot.gaussian.empirical_gaussian_gromov_wasserstein_mapping(xs, xt) + +xstgw = xs.dot(Agw) + bgw ############################################################################## # Plot transported samples # ------------------------ -plt.figure(1, (5, 5)) +plt.figure(2, (10, 5)) plt.clf() +plt.subplot(1, 2, 1) plt.plot(xs[:, 0], xs[:, 1], '+') plt.plot(xt[:, 0], xt[:, 1], 'o') plt.plot(xst[:, 0], xst[:, 1], '+') -plt.legend(('Source', 'Target', 'Transp'), loc=0) +plt.legend(('Source', 'Target', 'Transp. Monge'), loc=0) +plt.title('Transported samples with Monge') +plt.subplot(1, 2, 2) +plt.plot(xs[:, 0], xs[:, 1], '+') +plt.plot(xt[:, 0], xt[:, 1], 'o') +plt.plot(xstgw[:, 0], xstgw[:, 1], '+') +plt.legend(('Source', 'Target', 'Transp. GW'), loc=0) +plt.title('Transported samples with Gaussian GW') plt.show() ############################################################################## @@ -114,8 +130,8 @@ def minmax(img): # Estimate mapping and adapt # ---------------------------- +# Monge mapping mapping = ot.da.LinearTransport() - mapping.fit(Xs=X1, Xt=X2) @@ -125,6 +141,18 @@ def minmax(img): I1t = minmax(mat2im(xst, I1.shape)) I2t = minmax(mat2im(xts, I2.shape)) +# gaussian GW mapping + +mapping = ot.da.LinearGWTransport() +mapping.fit(Xs=X1, Xt=X2) + + +xstgw = mapping.transform(Xs=X1) +xtsgw = mapping.inverse_transform(Xt=X2) + +I1tgw = minmax(mat2im(xstgw, I1.shape)) +I2tgw = minmax(mat2im(xtsgw, I2.shape)) + # %% @@ -132,24 +160,34 @@ def minmax(img): # Plot transformed images # ----------------------- -plt.figure(2, figsize=(10, 7)) +plt.figure(3, figsize=(14, 7)) -plt.subplot(2, 2, 1) +plt.subplot(2, 3, 1) plt.imshow(I1) plt.axis('off') plt.title('Im. 1') -plt.subplot(2, 2, 2) +plt.subplot(2, 3, 4) plt.imshow(I2) plt.axis('off') plt.title('Im. 2') -plt.subplot(2, 2, 3) +plt.subplot(2, 3, 2) plt.imshow(I1t) plt.axis('off') -plt.title('Mapping Im. 1') +plt.title('Monge mapping Im. 1') -plt.subplot(2, 2, 4) +plt.subplot(2, 3, 5) plt.imshow(I2t) plt.axis('off') -plt.title('Inverse mapping Im. 2') +plt.title('Inverse Monge mapping Im. 2') + +plt.subplot(2, 3, 3) +plt.imshow(I1tgw) +plt.axis('off') +plt.title('Gaussian GW mapping Im. 1') + +plt.subplot(2, 3, 6) +plt.imshow(I2tgw) +plt.axis('off') +plt.title('Inverse Gaussian GW mapping Im. 2') From 5881e7906e4b1caeca6331259d180b2f21653039 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?R=C3=A9mi=20Flamary?= Date: Tue, 8 Aug 2023 11:22:37 +0200 Subject: [PATCH 08/13] pep8 --- ot/da.py | 1 - 1 file changed, 1 deletion(-) diff --git a/ot/da.py b/ot/da.py index c8ab966cf..1437d459f 100644 --- a/ot/da.py +++ b/ot/da.py @@ -1423,7 +1423,6 @@ class label self : object Returns self. 
""" - nx = self._get_backend(Xs, ys, Xt, yt) self.mu_s = self.distribution_estimation(Xs) self.mu_t = self.distribution_estimation(Xt) From 16a147e0266d7bd8cfdabde363715fce6b2ed1ab Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?R=C3=A9mi=20Flamary?= Date: Tue, 8 Aug 2023 13:26:52 +0200 Subject: [PATCH 09/13] gaussian gromov with sskew signe alignment --- ot/backend.py | 24 +++++++++++++ ot/da.py | 11 +++--- ot/gaussian.py | 84 ++++++++++++++++++++++++++++++------------- test/test_backend.py | 4 +++ test/test_gaussian.py | 5 ++- 5 files changed, 97 insertions(+), 31 deletions(-) diff --git a/ot/backend.py b/ot/backend.py index 974234831..7b2fe875f 100644 --- a/ot/backend.py +++ b/ot/backend.py @@ -338,6 +338,15 @@ def minimum(self, a, b): """ raise NotImplementedError() + def sign(self, a): + r""" Returns an element-wise indication of the sign of a number. + + This function follows the api from :any:`numpy.sign` + + See: https://numpy.org/doc/stable/reference/generated/numpy.sign.html + """ + raise NotImplementedError() + def dot(self, a, b): r""" Returns the dot product of two tensors. @@ -1057,6 +1066,9 @@ def maximum(self, a, b): def minimum(self, a, b): return np.minimum(a, b) + def sign(self, a): + return np.sign(a) + def dot(self, a, b): return np.dot(a, b) @@ -1428,6 +1440,9 @@ def maximum(self, a, b): def minimum(self, a, b): return jnp.minimum(a, b) + def sign(self, a): + return jnp.sign(a) + def dot(self, a, b): return jnp.dot(a, b) @@ -1845,6 +1860,9 @@ def minimum(self, a, b): else: return torch.min(torch.stack(torch.broadcast_tensors(a, b)), axis=0)[0] + def sign(self, a): + return torch.sign(a) + def dot(self, a, b): return torch.matmul(a, b) @@ -2267,6 +2285,9 @@ def maximum(self, a, b): def minimum(self, a, b): return cp.minimum(a, b) + def sign(self, a): + return cp.sign(a) + def abs(self, a): return cp.abs(a) @@ -2664,6 +2685,9 @@ def maximum(self, a, b): def minimum(self, a, b): return tnp.minimum(a, b) + def sign(self, a): + return tnp.sign(a) + def dot(self, a, b): if len(b.shape) == 1: if len(a.shape) == 1: diff --git a/ot/da.py b/ot/da.py index 1437d459f..a442f620c 100644 --- a/ot/da.py +++ b/ot/da.py @@ -1269,6 +1269,7 @@ class label Returns self. """ nx = self._get_backend(Xs, ys, Xt, yt) + self.nx = nx self.mu_s = self.distribution_estimation(Xs) self.mu_t = self.distribution_estimation(Xt) @@ -1423,14 +1424,16 @@ class label self : object Returns self. 
""" + nx = self._get_backend(Xs, ys, Xt, yt) + self.nx = nx self.mu_s = self.distribution_estimation(Xs) self.mu_t = self.distribution_estimation(Xt) # coupling estimation returned_ = empirical_gaussian_gromov_wasserstein_mapping(Xs, Xt, - ws=self.mu_s, - wt=self.mu_t, + ws=self.mu_s[:, None], + wt=self.mu_t[:, None], log=self.log) # deal with the value of log @@ -1442,8 +1445,8 @@ class label # re compute inverse mapping returned_1_ = empirical_gaussian_gromov_wasserstein_mapping(Xt, Xs, - ws=self.mu_t, - wt=self.mu_s, + ws=self.mu_t[:, None], + wt=self.mu_s[:, None], log=self.log) if self.log: self.A1_, self.B1_, self.log_1_ = returned_1_ diff --git a/ot/gaussian.py b/ot/gaussian.py index 8a3c1dcbf..e4d9ffb23 100644 --- a/ot/gaussian.py +++ b/ot/gaussian.py @@ -403,9 +403,9 @@ def empirical_gaussian_gromov_wasserstein_distance(xs, xt, ws=None, samples in the source domain xt : array-like (nt,d) samples in the target domain - ws : array-like (ns), optional + ws : array-like (ns,1), optional weights for the source samples - wt : array-like (ns), optional + wt : array-like (ns,1), optional weights for the target samples log : bool, optional record log if True @@ -428,19 +428,19 @@ def empirical_gaussian_gromov_wasserstein_distance(xs, xt, ws=None, nx = get_backend(xs, xt) if ws is None: - ws = nx.ones((xs.shape[0]), type_as=xs) / xs.shape[0] + ws = nx.ones((xs.shape[0], 1), type_as=xs) / xs.shape[0] if wt is None: - wt = nx.ones((xt.shape[0]), type_as=xt) / xt.shape[0] + wt = nx.ones((xt.shape[0], 1), type_as=xt) / xt.shape[0] - mxs = nx.dot(ws, xs) / nx.sum(ws) - mxt = nx.dot(wt, xt) / nx.sum(wt) + mxs = nx.dot(ws.T, xs) / nx.sum(ws) + mxt = nx.dot(wt.T, xt) / nx.sum(wt) xs = xs - mxs xt = xt - mxt - Cs = nx.dot((xs * ws[:, None]).T, xs) / nx.sum(ws) - Ct = nx.dot((xt * wt[:, None]).T, xt) / nx.sum(wt) + Cs = nx.dot((xs * ws).T, xs) / nx.sum(ws) + Ct = nx.dot((xt * wt).T, xt) / nx.sum(wt) if log: G, log = gaussian_gromov_wasserstein_distance(Cs, Ct, log=log) @@ -549,8 +549,10 @@ def empirical_gaussian_gromov_wasserstein_mapping(xs, xt, ws=None, weights for the source samples wt : array-like (ns,1), optional weights for the target samples - sign_eigs : array-like (min(ds,dt),), optional - sign of the eigenvalues of the mapping matrix + sign_eigs : array-like (min(ds,dt),) or string, optional + sign of the eigenvalues of the mapping matrix, by default all signs will + be positive. If 'skewness' is provided, the sign of the eigenvalues is + selected as the product of the sign of the skewness of the projected data. 
log : bool, optional record log if True @@ -571,30 +573,64 @@ def empirical_gaussian_gromov_wasserstein_mapping(xs, xt, ws=None, """ xs, xt = list_to_array(xs, xt) + nx = get_backend(xs, xt) + m = xs.shape[1] + n = xt.shape[1] + if ws is None: - ws = nx.ones((xs.shape[0]), type_as=xs) / xs.shape[0] + ws = nx.ones((xs.shape[0], 1), type_as=xs) / xs.shape[0] if wt is None: - wt = nx.ones((xt.shape[0]), type_as=xt) / xt.shape[0] + wt = nx.ones((xt.shape[0], 1), type_as=xt) / xt.shape[0] - mxs = nx.dot(ws, xs) / nx.sum(ws) - mxt = nx.dot(wt, xt) / nx.sum(wt) + # estimate mean and covariance + mu_s = nx.dot(ws.T, xs) / nx.sum(ws) + mu_t = nx.dot(wt.T, xt) / nx.sum(wt) - xs = xs - mxs - xt = xt - mxt + xs = xs - mu_s + xt = xt - mu_t - Cs = nx.dot((xs * ws[:, None]).T, xs) / nx.sum(ws) - Ct = nx.dot((xt * wt[:, None]).T, xt) / nx.sum(wt) + Cov_s = nx.dot((xs * ws).T, xs) / nx.sum(ws) + Cov_t = nx.dot((xt * wt).T, xt) / nx.sum(wt) - if log: + # compte and sort eigenvalues/eigenvectors decreasingly + d_s, U_s = nx.eigh(Cov_s) + id_s = nx.flip(nx.argsort(d_s)) + d_s, U_s = d_s[id_s], U_s[:, id_s] - A, b, log = gaussian_gromov_wasserstein_mapping(mxs, mxt, Cs, Ct, sign_eigs=sign_eigs, log=log) - log['Cov_s'] = Cs - log['Cov_t'] = Ct - return A, b, log + d_t, U_t = nx.eigh(Cov_t) + id_t = nx.flip(nx.argsort(d_t)) + d_t, U_t = d_t[id_t], U_t[:, id_t] + + # select the sign of the eigenvalues + if sign_eigs is None: + sign_eigs = nx.ones(min(m, n), type_as=mu_s) + elif sign_eigs == 'skewness': + size = min(m, n) + skew_s = nx.sum((nx.dot(xs, U_s[:, :size]))**3 * ws, axis=0) + skew_t = nx.sum((nx.dot(xt, U_t[:, :size]))**3 * wt, axis=0) + sign_eigs = nx.sign(skew_t * skew_s) + + if m >= n: + A = nx.concatenate((nx.diag(sign_eigs * nx.sqrt(d_t) / nx.sqrt(d_s[:n])), nx.zeros((n, m - n), type_as=mu_s)), axis=1).T + else: + A = nx.concatenate((nx.diag(sign_eigs * nx.sqrt(d_t[:m]) / nx.sqrt(d_s)), nx.zeros((n - m, m), type_as=mu_s)), axis=0).T + + A = nx.dot(nx.dot(U_s, A), U_t.T) + # compute the gaussien Gromov-Wasserstein dis + b = mu_t - nx.dot(mu_s, A) + + if log: + log = {} + log['d_s'] = d_s + log['d_t'] = d_t + log['U_s'] = U_s + log['U_t'] = U_t + log['Cov_s'] = Cov_s + log['Cov_t'] = Cov_t + return A, b, log else: - A, b = gaussian_gromov_wasserstein_mapping(mxs, mxt, Cs, Ct, sign_eigs=sign_eigs) return A, b diff --git a/test/test_backend.py b/test/test_backend.py index b161746bf..f0571471c 100644 --- a/test/test_backend.py +++ b/test/test_backend.py @@ -371,6 +371,10 @@ def test_func_backends(nx): lst_b.append(nx.to_numpy(A)) lst_name.append('minimum') + A = nx.sign(vb) + lst_b.append(nx.to_numpy(A)) + lst_name.append('sign') + A = nx.abs(Mb) lst_b.append(nx.to_numpy(A)) lst_name.append('abs') diff --git a/test/test_gaussian.py b/test/test_gaussian.py index 8e5ca048b..5a021d004 100644 --- a/test/test_gaussian.py +++ b/test/test_gaussian.py @@ -147,14 +147,13 @@ def test_gaussian_gromov_wasserstein_mapping(nx, d_target): A, b, log = ot.gaussian.gaussian_gromov_wasserstein_mapping(msb, mtb, Csb, Ctb, log=True) Ae, be, loge = ot.gaussian.empirical_gaussian_gromov_wasserstein_mapping(Xsb, Xtb, log=True) - # no log - Ae0, be0 = ot.gaussian.empirical_gaussian_gromov_wasserstein_mapping(Xsb, Xtb, log=False) + # no log + skewness + Ae0, be0 = ot.gaussian.empirical_gaussian_gromov_wasserstein_mapping(Xsb, Xtb, log=False, sign_eigs='skewness') Xst = nx.to_numpy(nx.dot(Xsb, A) + b) Cst = np.cov(Xst.T) np.testing.assert_allclose(nx.to_numpy(A), nx.to_numpy(Ae)) - np.testing.assert_allclose(nx.to_numpy(A), 
nx.to_numpy(Ae0)) if d_target <= 2: np.testing.assert_allclose(Ct, Cst) From 1208297dae5f792d2dbcfe83b9291dee8d351c19 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?R=C3=A9mi=20Flamary?= Date: Tue, 8 Aug 2023 13:34:00 +0200 Subject: [PATCH 10/13] add sign_eigs to DA class --- ot/da.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/ot/da.py b/ot/da.py index a442f620c..f89aaf403 100644 --- a/ot/da.py +++ b/ot/da.py @@ -1382,6 +1382,10 @@ class LinearGWTransport(LinearTransport): Parameters ---------- + sign_eigs : array-like (n_features), str, optional + sign of the eigenvalues of the mapping matrix, by default all signs will + be positive. If 'skewness' is provided, the sign of the eigenvalues is + selected as the product of the sign of the skewness of the projected data. log : bool, optional record log if True @@ -1395,8 +1399,9 @@ class LinearGWTransport(LinearTransport): """ - def __init__(self, log=False, + def __init__(self, log=False, sign_eigs=None, distribution_estimation=distribution_estimation_uniform): + self.sign_eigs = sign_eigs self.log = log self.distribution_estimation = distribution_estimation @@ -1434,6 +1439,7 @@ class label returned_ = empirical_gaussian_gromov_wasserstein_mapping(Xs, Xt, ws=self.mu_s[:, None], wt=self.mu_t[:, None], + sign_eigs=self.sign_eigs, log=self.log) # deal with the value of log @@ -1447,6 +1453,7 @@ class label returned_1_ = empirical_gaussian_gromov_wasserstein_mapping(Xt, Xs, ws=self.mu_t[:, None], wt=self.mu_s[:, None], + sign_eigs=self.sign_eigs, log=self.log) if self.log: self.A1_, self.B1_, self.log_1_ = returned_1_ From 78d5ec81791b3d8c3fee44604a4a70ed790aadc0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?R=C3=A9mi=20Flamary?= Date: Tue, 8 Aug 2023 14:11:05 +0200 Subject: [PATCH 11/13] change in release file --- RELEASES.md | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/RELEASES.md b/RELEASES.md index a81fdeeea..9c0df73b5 100644 --- a/RELEASES.md +++ b/RELEASES.md @@ -5,7 +5,7 @@ This new release contains several new features and bug fixes. -New features include a new submodule `ot.gnn` that contains two new Graph neural network layers (compatible with [Pytorch Geometric](https://pytorch-geometric.readthedocs.io/)) for template-based pooling of graphs with an example on [graph classification](https://pythonot.github.io/master/auto_examples/gromov/plot_gnn_TFGW.html). Related to this, we also now provide FGW and semi relaxed FGW solvers for which the resulting loss is differentiable w.r.t. the parameter `alpha`. Other contributions on the (F)GW front include a new solver for the Proximal Point algorithm [that can be used to solve entropic GW problems](https://pythonot.github.io/master/auto_examples/gromov/plot_fgw_solvers.html) (using the parameter `solver="PPA"`), novels Sinkhorn-based solvers for entropic semi-relaxed (F)GW, the possibility to provide a warm-start to the solvers, and optional marginal weights of the samples (uniform weights ar used by default). +New features include a new submodule `ot.gnn` that contains two new Graph neural network layers (compatible with [Pytorch Geometric](https://pytorch-geometric.readthedocs.io/)) for template-based pooling of graphs with an example on [graph classification](https://pythonot.github.io/master/auto_examples/gromov/plot_gnn_TFGW.html). Related to this, we also now provide FGW and semi relaxed FGW solvers for which the resulting loss is differentiable w.r.t. the parameter `alpha`. 
Other contributions on the (F)GW front include a new solver for the Proximal Point algorithm [that can be used to solve entropic GW problems](https://pythonot.github.io/master/auto_examples/gromov/plot_fgw_solvers.html) (using the parameter `solver="PPA"`), novel Sinkhorn-based solvers for entropic semi-relaxed (F)GW, the possibility to provide a warm-start to the solvers, and optional marginal weights of the samples (uniform weights are used by default). Finally we added in the submodules `ot.gaussian` and `ot.da` new loss and mapping estimators for the Gaussian Gromov-Wasserstein that can be used as a fast alternative to GW and estimate linear mappings between unregistered spaces that can potentially have different sizes (see the updated [linear mapping example](https://pythonot.github.io/master/auto_examples/domain-adaptation/plot_otda_linear_mapping.html) for an illustration).

We also provide a new solver for the [Entropic Wasserstein Component Analysis](https://pythonot.github.io/master/auto_examples/others/plot_EWCA.html) that is a generalization of the celebrated PCA taking into account the local neighborhood of the samples. We also now have a new solver in `ot.smooth` for the [sparsity-constrained OT (last plot)](https://pythonot.github.io/master/auto_examples/plot_OT_1D_smooth.html) that can be used to find regularized OT plans with sparsity constraints. Finally we have a first multi-marginal solver for regular 1D distributions with a Monge loss (see [here](https://pythonot.github.io/master/auto_examples/others/plot_dmmot.html)).

Many other bugs and issues have been fixed and we want to thank all the contributors.

#### New features

+- Gaussian Gromov Wasserstein loss and mapping (PR #498)
 - Template-based Fused Gromov Wasserstein GNN layer in `ot.gnn` (PR #488)
 - Make alpha parameter in semi-relaxed Fused Gromov Wasserstein differentiable (PR #483)
 - Make alpha parameter in Fused Gromov Wasserstein differentiable (PR #463)

From 38b22fbfebddda9a71b145b178e3d4eebc5f0b01 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?R=C3=A9mi=20Flamary?=
Date: Tue, 8 Aug 2023 14:45:00 +0200
Subject: [PATCH 12/13] debug code coverage

---
 .github/workflows/build_tests.yml | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/.github/workflows/build_tests.yml b/.github/workflows/build_tests.yml
index a5e876b96..492032ffc 100644
--- a/.github/workflows/build_tests.yml
+++ b/.github/workflows/build_tests.yml
@@ -40,7 +40,7 @@ jobs:
         pip install pytest pytest-cov
     - name: Run tests
       run: |
-        python -m pytest --durations=20 -v test/ ot/ --doctest-modules --color=yes --cov-report=xml
+        python -m pytest --durations=20 -v test/ ot/ --doctest-modules --color=yes --cov=./ --cov-report=xml
     - name: Upload coverage reports to Codecov with GitHub Action
       uses: codecov/codecov-action@v3

From bde1e193c316d3051c4597c469f4a311d0aa5625 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?R=C3=A9mi=20Flamary?=
Date: Tue, 8 Aug 2023 15:46:45 +0200
Subject: [PATCH 13/13] documentation fix

---
 RELEASES.md    |  2 +-
 ot/da.py       |  2 +-
 ot/gaussian.py | 12 ++++++------
 3 files changed, 8 insertions(+), 8 deletions(-)

diff --git a/RELEASES.md b/RELEASES.md
index 9c0df73b5..d0209e233 100644
--- a/RELEASES.md
+++ b/RELEASES.md
@@ -5,7 +5,7 @@

 This new release contains several new features and bug fixes.
-New features include a new submodule `ot.gnn` that contains two new Graph neural network layers (compatible with [Pytorch Geometric](https://pytorch-geometric.readthedocs.io/)) for template-based pooling of graphs with an example on [graph classification](https://pythonot.github.io/master/auto_examples/gromov/plot_gnn_TFGW.html). Related to this, we also now provide FGW and semi relaxed FGW solvers for which the resulting loss is differentiable w.r.t. the parameter `alpha`. Other contributions on the (F)GW front include a new solver for the Proximal Point algorithm [that can be used to solve entropic GW problems](https://pythonot.github.io/master/auto_examples/gromov/plot_fgw_solvers.html) (using the parameter `solver="PPA"`), novel Sinkhorn-based solvers for entropic semi-relaxed (F)GW, the possibility to provide a warm-start to the solvers, and optional marginal weights of the samples (uniform weights are used by default). Finally we added in the submodules `ot.gaussian` and `ot.da` new loss and mapping estimators for the Gaussian Gromov-Wasserstein that can be used as a fast alternative to GW and estimate linear mappings between unregistered spaces that can potentially have different sizes (see the updated [linear mapping example](https://pythonot.github.io/master/auto_examples/domain-adaptation/plot_otda_linear_mapping.html) for an illustration).
+New features include a new submodule `ot.gnn` that contains two new Graph neural network layers (compatible with [Pytorch Geometric](https://pytorch-geometric.readthedocs.io/)) for template-based pooling of graphs with an example on [graph classification](https://pythonot.github.io/master/auto_examples/gromov/plot_gnn_TFGW.html). Related to this, we also now provide FGW and semi relaxed FGW solvers for which the resulting loss is differentiable w.r.t. the parameter `alpha`. Other contributions on the (F)GW front include a new solver for the Proximal Point algorithm [that can be used to solve entropic GW problems](https://pythonot.github.io/master/auto_examples/gromov/plot_fgw_solvers.html) (using the parameter `solver="PPA"`), new solvers for entropic FGW barycenters, novel Sinkhorn-based solvers for entropic semi-relaxed (F)GW, the possibility to provide a warm-start to the solvers, and optional marginal weights of the samples (uniform weights are used by default). Finally we added in the submodules `ot.gaussian` and `ot.da` new loss and mapping estimators for the Gaussian Gromov-Wasserstein that can be used as a fast alternative to GW and estimate linear mappings between unregistered spaces that can potentially have different sizes (see the updated [linear mapping example](https://pythonot.github.io/master/auto_examples/domain-adaptation/plot_otda_linear_mapping.html) for an illustration).

 We also provide a new solver for the [Entropic Wasserstein Component Analysis](https://pythonot.github.io/master/auto_examples/others/plot_EWCA.html) that is a generalization of the celebrated PCA taking into account the local neighborhood of the samples. We also now have a new solver in `ot.smooth` for the [sparsity-constrained OT (last plot)](https://pythonot.github.io/master/auto_examples/plot_OT_1D_smooth.html) that can be used to find regularized OT plans with sparsity constraints. Finally we have a first multi-marginal solver for regular 1D distributions with a Monge loss (see [here](https://pythonot.github.io/master/auto_examples/others/plot_dmmot.html)).
diff --git a/ot/da.py b/ot/da.py index f89aaf403..5d55f53c7 100644 --- a/ot/da.py +++ b/ot/da.py @@ -1365,7 +1365,7 @@ class LinearGWTransport(LinearTransport): r""" OT Gaussian Gromov-Wasserstein linear operator between empirical distributions The function estimates the optimal linear operator that aligns the two - empirical distributions optimaly wrt the Gromov wassretsein distance. This is equivalent to estimating the closed + empirical distributions optimally wrt the Gromov-Wasserstein distance. This is equivalent to estimating the closed form mapping between two Gaussian distributions :math:`\mathcal{N}(\mu_s,\Sigma_s)` and :math:`\mathcal{N}(\mu_t,\Sigma_t)` as proposed in :ref:`[57] `. diff --git a/ot/gaussian.py b/ot/gaussian.py index e4d9ffb23..e83d5eee8 100644 --- a/ot/gaussian.py +++ b/ot/gaussian.py @@ -42,9 +42,9 @@ def bures_wasserstein_mapping(ms, mt, Cs, Ct, log=False): mean of the source distribution mt : array-like (d,) mean of the target distribution - Cs : array-like (d,) + Cs : array-like (d,d) covariance of the source distribution - Ct : array-like (d,) + Ct : array-like (d,d) covariance of the target distribution log : bool, optional record log if True @@ -210,9 +210,9 @@ def bures_wasserstein_distance(ms, mt, Cs, Ct, log=False): mean of the source distribution mt : array-like (d,) mean of the target distribution - Cs : array-like (d,) + Cs : array-like (d,d) covariance of the source distribution - Ct : array-like (d,) + Ct : array-like (d,d) covariance of the target distribution log : bool, optional record log if True @@ -344,9 +344,9 @@ def gaussian_gromov_wasserstein_distance(Cov_s, Cov_t, log=False): Parameters ---------- - Cov_s : array-like (d,d) + Cov_s : array-like (ds,ds) covariance of the source distribution - Cov_t : array-like (d,d) + Cov_t : array-like (dt,dt) covariance of the target distribution