Skip to content

Commit

Permalink
fix merge
Browse files Browse the repository at this point in the history
  • Loading branch information
clbonet committed Oct 31, 2024
1 parent 9377405 commit 4f648bb
Showing 1 changed file with 47 additions and 24 deletions.
71 changes: 47 additions & 24 deletions ot/gaussian.py
Original file line number Diff line number Diff line change
Expand Up @@ -393,7 +393,9 @@ def bures_barycenter_fixpoint(C, weights=None, num_iter=1000, eps=1e-7, log=Fals
SIAM Journal on Mathematical Analysis, vol. 43, no. 2, pp. 904-924,
2011.
"""
nx = get_backend(*C,)
nx = get_backend(
*C,
)

if weights is None:
weights = nx.ones(C.shape[0], type_as=C[0]) / C.shape[0]
Expand All @@ -420,19 +422,21 @@ def bures_barycenter_fixpoint(C, weights=None, num_iter=1000, eps=1e-7, log=Fals

if log:
log = {}
log['num_iter'] = it
log['final_diff'] = diff
log["num_iter"] = it
log["final_diff"] = diff
return Cb, log
else:
return Cb


def bures_barycenter_gradient_descent(C, weights=None, num_iter=1000, eps=1e-7, log=False, step_size=1, batch_size=None):
def bures_barycenter_gradient_descent(
C, weights=None, num_iter=1000, eps=1e-7, log=False, step_size=1, batch_size=None
):
r"""Return the (Bures-)Wasserstein barycenter between centered Gaussian distributions.
The function estimates the (Bures)-Wasserstein barycenter between centered Gaussian distributions :math:`\big(\mathcal{N}(0,\Sigma_i)\big)_{i=1}^n`
by using a gradient descent in the Wasserstein space :ref:`[74, 75] <references-OT-bures-barycenter-gradient_descent>`
on the objective
on the objective
.. math::
\mathcal{L}(\Sigma) = \sum_{i=1}^n w_i W_2^2\big(\mathcal{N}(0,\Sigma), \mathcal{N}(0,\Sigma_i)\big).
Expand Down Expand Up @@ -475,7 +479,9 @@ def bures_barycenter_gradient_descent(C, weights=None, num_iter=1000, eps=1e-7,
Averaging on the Bures-Wasserstein manifold: dimension-free convergence
of gradient descent. Advances in Neural Information Processing Systems, 34, 22132-22145.
"""
nx = get_backend(*C,)
nx = get_backend(
*C,
)

n = C.shape[0]

Expand All @@ -492,9 +498,12 @@ def bures_barycenter_gradient_descent(C, weights=None, num_iter=1000, eps=1e-7,

if batch_size is not None and batch_size < n: # if stochastic gradient descent
if batch_size <= 0:
raise ValueError("batch_size must be an integer between 0 and {}".format(n))
inds = np.random.choice(n, batch_size, replace=True,
p=nx._to_numpy(weights))
raise ValueError(
    "batch_size must be an integer between 0 and {}".format(n)
)
(Codecov warning for ot/gaussian.py lines 500–501: added lines were not covered by tests.)
inds = np.random.choice(
    n, batch_size, replace=True, p=nx._to_numpy(weights)
)
(Codecov warning for ot/gaussian.py line 504: added line was not covered by tests.)
M = nx.sqrtm(nx.einsum("ij,njk,kl -> nil", Cb12, C[inds], Cb12))
ot_maps = nx.einsum("ij,njk,kl -> nil", Cb12_, M, Cb12_)
grad_bw = Id - nx.mean(ot_maps, axis=0)
(Codecov warning for ot/gaussian.py lines 507–509: added lines were not covered by tests.)
Expand All @@ -503,7 +512,7 @@ def bures_barycenter_gradient_descent(C, weights=None, num_iter=1000, eps=1e-7,
ot_maps = nx.einsum("ij,njk,kl -> nil", Cb12_, M, Cb12_)
grad_bw = Id - nx.sum(ot_maps * weights[:, None, None], axis=0)

Cnew = exp_bures(Cb, - step_size * grad_bw, nx=nx)
Cnew = exp_bures(Cb, -step_size * grad_bw, nx=nx)

# check convergence
if batch_size is not None and batch_size < n:
Expand All @@ -522,16 +531,24 @@ def bures_barycenter_gradient_descent(C, weights=None, num_iter=1000, eps=1e-7,

if log:
log = {}
log['num_iter'] = it
log['final_diff'] = diff
log["num_iter"] = it
log["final_diff"] = diff
return Cb, log
else:
return Cb


def bures_wasserstein_barycenter(m, C, weights=None, method="fixed_point",
num_iter=1000, eps=1e-7, log=False,
step_size=1, batch_size=None):
def bures_wasserstein_barycenter(
m,
C,
weights=None,
method="fixed_point",
num_iter=1000,
eps=1e-7,
log=False,
step_size=1,
batch_size=None,
):
r"""Return the (Bures-)Wasserstein barycenter between Gaussian distributions.
The function estimates the (Bures)-Wasserstein barycenter between Gaussian distributions :math:`\big(\mathcal{N}(\mu_i,\Sigma_i)\big)_{i=1}^n`
Expand Down Expand Up @@ -601,7 +618,9 @@ def bures_wasserstein_barycenter(m, C, weights=None, method="fixed_point",
Averaging on the Bures-Wasserstein manifold: dimension-free convergence
of gradient descent. Advances in Neural Information Processing Systems, 34, 22132-22145.
"""
nx = get_backend(*m,)
nx = get_backend(
*m,
)

if weights is None:
weights = nx.ones(C.shape[0], type_as=C[0]) / C.shape[0]
Expand All @@ -610,20 +629,24 @@ def bures_wasserstein_barycenter(m, C, weights=None, method="fixed_point",
mb = nx.sum(m * weights[:, None], axis=0)

if method == "gradient_descent" or batch_size is not None:
out = bures_barycenter_gradient_descent(C, weights=weights,
num_iter=num_iter, eps=eps,
log=log, step_size=step_size,
batch_size=batch_size)
out = bures_barycenter_gradient_descent(
C,
weights=weights,
num_iter=num_iter,
eps=eps,
log=log,
step_size=step_size,
batch_size=batch_size,
)
elif method == "fixed_point":
out = bures_barycenter_fixpoint(C, weights=weights, num_iter=num_iter,
eps=eps, log=log)
out = bures_barycenter_fixpoint(
C, weights=weights, num_iter=num_iter, eps=eps, log=log
)
else:
raise ValueError("Unknown method '%s'." % method)

(Codecov warning for ot/gaussian.py line 646: added line was not covered by tests.)

if log:
Cb, log = out
log["num_iter"] = it
log["final_diff"] = diff
return mb, Cb, log
else:
Cb = out
Expand Down

0 comments on commit 4f648bb

Please sign in to comment.