diff --git a/ot/gaussian.py b/ot/gaussian.py index 0d99d9fae..eb38f0bc0 100644 --- a/ot/gaussian.py +++ b/ot/gaussian.py @@ -393,7 +393,9 @@ def bures_barycenter_fixpoint(C, weights=None, num_iter=1000, eps=1e-7, log=Fals SIAM Journal on Mathematical Analysis, vol. 43, no. 2, pp. 904-924, 2011. """ - nx = get_backend(*C,) + nx = get_backend( + *C, + ) if weights is None: weights = nx.ones(C.shape[0], type_as=C[0]) / C.shape[0] @@ -420,19 +422,21 @@ def bures_barycenter_fixpoint(C, weights=None, num_iter=1000, eps=1e-7, log=Fals if log: log = {} - log['num_iter'] = it - log['final_diff'] = diff + log["num_iter"] = it + log["final_diff"] = diff return Cb, log else: return Cb -def bures_barycenter_gradient_descent(C, weights=None, num_iter=1000, eps=1e-7, log=False, step_size=1, batch_size=None): +def bures_barycenter_gradient_descent( + C, weights=None, num_iter=1000, eps=1e-7, log=False, step_size=1, batch_size=None +): r"""Return the (Bures-)Wasserstein barycenter between centered Gaussian distributions. The function estimates the (Bures)-Wasserstein barycenter between centered Gaussian distributions :math:`\big(\mathcal{N}(0,\Sigma_i)\big)_{i=1}^n` by using a gradient descent in the Wasserstein space :ref:`[74, 75] ` - on the objective + on the objective .. math:: \mathcal{L}(\Sigma) = \sum_{i=1}^n w_i W_2^2\big(\mathcal{N}(0,\Sigma), \mathcal{N}(0,\Sigma_i)\big). @@ -475,7 +479,9 @@ def bures_barycenter_gradient_descent(C, weights=None, num_iter=1000, eps=1e-7, Averaging on the Bures-Wasserstein manifold: dimension-free convergence of gradient descent. Advances in Neural Information Processing Systems, 34, 22132-22145. """ - nx = get_backend(*C,) + nx = get_backend( + *C, + ) n = C.shape[0] @@ -492,9 +498,12 @@ def bures_barycenter_gradient_descent(C, weights=None, num_iter=1000, eps=1e-7, if batch_size is not None and batch_size < n: # if stochastic gradient descent if batch_size <= 0: - raise ValueError("batch_size must be an integer between 0 and {}".format(n)) - inds = np.random.choice(n, batch_size, replace=True, - p=nx._to_numpy(weights)) + raise ValueError( + "batch_size must be an integer between 0 and {}".format(n) + ) + inds = np.random.choice( + n, batch_size, replace=True, p=nx._to_numpy(weights) + ) M = nx.sqrtm(nx.einsum("ij,njk,kl -> nil", Cb12, C[inds], Cb12)) ot_maps = nx.einsum("ij,njk,kl -> nil", Cb12_, M, Cb12_) grad_bw = Id - nx.mean(ot_maps, axis=0) @@ -503,7 +512,7 @@ def bures_barycenter_gradient_descent(C, weights=None, num_iter=1000, eps=1e-7, ot_maps = nx.einsum("ij,njk,kl -> nil", Cb12_, M, Cb12_) grad_bw = Id - nx.sum(ot_maps * weights[:, None, None], axis=0) - Cnew = exp_bures(Cb, - step_size * grad_bw, nx=nx) + Cnew = exp_bures(Cb, -step_size * grad_bw, nx=nx) # check convergence if batch_size is not None and batch_size < n: @@ -522,16 +531,24 @@ def bures_barycenter_gradient_descent(C, weights=None, num_iter=1000, eps=1e-7, if log: log = {} - log['num_iter'] = it - log['final_diff'] = diff + log["num_iter"] = it + log["final_diff"] = diff return Cb, log else: return Cb -def bures_wasserstein_barycenter(m, C, weights=None, method="fixed_point", - num_iter=1000, eps=1e-7, log=False, - step_size=1, batch_size=None): +def bures_wasserstein_barycenter( + m, + C, + weights=None, + method="fixed_point", + num_iter=1000, + eps=1e-7, + log=False, + step_size=1, + batch_size=None, +): r"""Return the (Bures-)Wasserstein barycenter between Gaussian distributions. The function estimates the (Bures)-Wasserstein barycenter between Gaussian distributions :math:`\big(\mathcal{N}(\mu_i,\Sigma_i)\big)_{i=1}^n` @@ -601,7 +618,9 @@ def bures_wasserstein_barycenter(m, C, weights=None, method="fixed_point", Averaging on the Bures-Wasserstein manifold: dimension-free convergence of gradient descent. Advances in Neural Information Processing Systems, 34, 22132-22145. """ - nx = get_backend(*m,) + nx = get_backend( + *m, + ) if weights is None: weights = nx.ones(C.shape[0], type_as=C[0]) / C.shape[0] @@ -610,20 +629,24 @@ def bures_wasserstein_barycenter(m, C, weights=None, method="fixed_point", mb = nx.sum(m * weights[:, None], axis=0) if method == "gradient_descent" or batch_size is not None: - out = bures_barycenter_gradient_descent(C, weights=weights, - num_iter=num_iter, eps=eps, - log=log, step_size=step_size, - batch_size=batch_size) + out = bures_barycenter_gradient_descent( + C, + weights=weights, + num_iter=num_iter, + eps=eps, + log=log, + step_size=step_size, + batch_size=batch_size, + ) elif method == "fixed_point": - out = bures_barycenter_fixpoint(C, weights=weights, num_iter=num_iter, - eps=eps, log=log) + out = bures_barycenter_fixpoint( + C, weights=weights, num_iter=num_iter, eps=eps, log=log + ) else: raise ValueError("Unknown method '%s'." % method) if log: Cb, log = out - log["num_iter"] = it - log["final_diff"] = diff return mb, Cb, log else: Cb = out