[MRG] Gaussian Gromov-Wasserstein solvers #498

Merged · 14 commits · Aug 8, 2023
2 changes: 1 addition & 1 deletion .github/workflows/build_tests.yml
@@ -40,7 +40,7 @@ jobs:
pip install pytest pytest-cov
- name: Run tests
run: |
python -m pytest --durations=20 -v test/ ot/ --doctest-modules --color=yes --cov-report=xml
python -m pytest --durations=20 -v test/ ot/ --doctest-modules --color=yes --cov=./ --cov-report=xml
- name: Upload coverage reports to Codecov with GitHub Action
uses: codecov/codecov-action@v3

5 changes: 5 additions & 0 deletions README.md
@@ -324,3 +324,8 @@ Dictionary Learning](https://arxiv.org/pdf/2102.06555.pdf), International Confer
[55] Ronak Mehta, Jeffery Kline, Vishnu Suresh Lokhande, Glenn Fung, & Vikas Singh (2023). [Efficient Discrete Multi Marginal Optimal Transport Regularization](https://openreview.net/forum?id=R98ZfMt-jE). In The Eleventh International Conference on Learning Representations (ICLR).

[56] Jeffery Kline. [Properties of the d-dimensional earth mover’s problem](https://www.sciencedirect.com/science/article/pii/S0166218X19301441). Discrete Applied Mathematics, 265: 128–141, 2019.

[57] Delon, J., Desolneux, A., & Salmona, A. (2022). [Gromov–Wasserstein distances between Gaussian distributions](https://hal.science/hal-03197398v2/file/main.pdf). Journal of Applied Probability, 59(4), 1178–1198.

3 changes: 2 additions & 1 deletion RELEASES.md
@@ -5,7 +5,7 @@

This new release contains several new features and bug fixes.

New features include a new submodule `ot.gnn` that contains two new graph neural network layers (compatible with [PyTorch Geometric](https://pytorch-geometric.readthedocs.io/)) for template-based pooling of graphs, with an example on [graph classification](https://pythonot.github.io/master/auto_examples/gromov/plot_gnn_TFGW.html). Related to this, we also now provide FGW and semi-relaxed FGW solvers for which the resulting loss is differentiable w.r.t. the parameter `alpha`. Other contributions on the (F)GW front include a new solver for the Proximal Point algorithm [that can be used to solve entropic GW problems](https://pythonot.github.io/master/auto_examples/gromov/plot_fgw_solvers.html) (using the parameter `solver="PPA"`), novel Sinkhorn-based solvers for entropic semi-relaxed (F)GW, the possibility to provide a warm start to the solvers, and optional marginal weights of the samples (uniform weights are used by default).
New features include a new submodule `ot.gnn` that contains two new graph neural network layers (compatible with [PyTorch Geometric](https://pytorch-geometric.readthedocs.io/)) for template-based pooling of graphs, with an example on [graph classification](https://pythonot.github.io/master/auto_examples/gromov/plot_gnn_TFGW.html). Related to this, we also now provide FGW and semi-relaxed FGW solvers for which the resulting loss is differentiable w.r.t. the parameter `alpha`. Other contributions on the (F)GW front include a new solver for the Proximal Point algorithm [that can be used to solve entropic GW problems](https://pythonot.github.io/master/auto_examples/gromov/plot_fgw_solvers.html) (using the parameter `solver="PPA"`), novel Sinkhorn-based solvers for entropic semi-relaxed (F)GW, the possibility to provide a warm start to the solvers, and optional marginal weights of the samples (uniform weights are used by default). Finally, we added to the submodules `ot.gaussian` and `ot.da` new loss and mapping estimators for the Gaussian Gromov-Wasserstein distance that can be used as a fast alternative to GW and estimate linear mappings between unregistered spaces that can potentially have different sizes (see the updated [linear mapping example](https://pythonot.github.io/master/auto_examples/domain-adaptation/plot_otda_linear_mapping.html) for an illustration).
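For orientation, here is a minimal sketch of the new Gaussian GW API described above. The function and class names are taken from this PR's diff; the toy data, shapes, and random seed are illustrative assumptions, not part of the PR:

```python
# Minimal sketch of the new Gaussian Gromov-Wasserstein estimators.
# Names (empirical_gaussian_gromov_wasserstein_mapping, LinearGWTransport)
# appear in this PR's diff; the toy data below is assumed.
import numpy as np
import ot

rng = np.random.RandomState(0)
xs = rng.randn(100, 2)          # source samples in R^2
xt = rng.randn(100, 3) + 4.0    # target samples in R^3 (different dimension)

# Closed-form affine map estimated from empirical means and covariances
A, b = ot.gaussian.empirical_gaussian_gromov_wasserstein_mapping(xs, xt)
xst = xs.dot(A) + b             # transported source samples

# The same estimator wrapped as a domain-adaptation object
mapping = ot.da.LinearGWTransport()
mapping.fit(Xs=xs, Xt=xt)
xst2 = mapping.transform(Xs=xs)
```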

We also provide a new solver for the [Entropic Wasserstein Component Analysis](https://pythonot.github.io/master/auto_examples/others/plot_EWCA.html), a generalization of the celebrated PCA that takes into account the local neighborhood of the samples. We also now have a new solver in `ot.smooth` for [sparsity-constrained OT (last plot)](https://pythonot.github.io/master/auto_examples/plot_OT_1D_smooth.html) that can be used to find regularized OT plans with sparsity constraints. Finally, we have a first multi-marginal solver for regular 1D distributions with a Monge loss (see [here](https://pythonot.github.io/master/auto_examples/others/plot_dmmot.html)).

@@ -15,6 +15,7 @@ Many other bugs and issues have been fixed and we want to thank all the contribu


#### New features
- Gaussian Gromov Wasserstein loss and mapping (PR #498)
- Template-based Fused Gromov Wasserstein GNN layer in `ot.gnn` (PR #488)
- Make alpha parameter in semi-relaxed Fused Gromov Wasserstein differentiable (PR #483)
- Make alpha parameter in Fused Gromov Wasserstein differentiable (PR #463)
62 changes: 51 additions & 11 deletions examples/domain-adaptation/plot_otda_linear_mapping.py
@@ -13,6 +13,8 @@
# License: MIT License

# sphinx_gallery_thumbnail_number = 2

#%%
import os
from pathlib import Path

@@ -55,27 +57,43 @@
plt.figure(1, (5, 5))
plt.plot(xs[:, 0], xs[:, 1], '+')
plt.plot(xt[:, 0], xt[:, 1], 'o')

plt.legend(('Source', 'Target'))
plt.title('Source and target distributions')
plt.show()

##############################################################################
# Estimate linear mapping and transport
# -------------------------------------


# Gaussian (linear) Monge mapping estimation
Ae, be = ot.gaussian.empirical_bures_wasserstein_mapping(xs, xt)

xst = xs.dot(Ae) + be

# Gaussian (linear) GW mapping estimation
Agw, bgw = ot.gaussian.empirical_gaussian_gromov_wasserstein_mapping(xs, xt)

xstgw = xs.dot(Agw) + bgw
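For context on what these estimators return: `empirical_bures_wasserstein_mapping` computes the classical closed-form Monge map between the Gaussians fitted to the data, a standard result recalled here (the Gaussian GW mapping of [57] plays the analogous role across spaces of possibly different dimensions):

$$T(x) = A x + b, \qquad A = \Sigma_s^{-1/2}\left(\Sigma_s^{1/2} \Sigma_t \Sigma_s^{1/2}\right)^{1/2} \Sigma_s^{-1/2}, \qquad b = m_t - A m_s,$$

where $(m_s, \Sigma_s)$ and $(m_t, \Sigma_t)$ are the empirical means and covariances of the source and target samples.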

##############################################################################
# Plot transported samples
# ------------------------

plt.figure(1, (5, 5))
plt.figure(2, (10, 5))
plt.clf()
plt.subplot(1, 2, 1)
plt.plot(xs[:, 0], xs[:, 1], '+')
plt.plot(xt[:, 0], xt[:, 1], 'o')
plt.plot(xst[:, 0], xst[:, 1], '+')

plt.legend(('Source', 'Target', 'Transp. Monge'), loc=0)
plt.title('Transported samples with Monge')
plt.subplot(1, 2, 2)
plt.plot(xs[:, 0], xs[:, 1], '+')
plt.plot(xt[:, 0], xt[:, 1], 'o')
plt.plot(xstgw[:, 0], xstgw[:, 1], '+')
plt.legend(('Source', 'Target', 'Transp. GW'), loc=0)
plt.title('Transported samples with Gaussian GW')
plt.show()

##############################################################################
@@ -112,8 +130,8 @@ def minmax(img):
# Estimate mapping and adapt
# ----------------------------

# Monge mapping
mapping = ot.da.LinearTransport()

mapping.fit(Xs=X1, Xt=X2)


@@ -123,31 +141,53 @@ def minmax(img):
I1t = minmax(mat2im(xst, I1.shape))
I2t = minmax(mat2im(xts, I2.shape))

# gaussian GW mapping

mapping = ot.da.LinearGWTransport()
mapping.fit(Xs=X1, Xt=X2)


xstgw = mapping.transform(Xs=X1)
xtsgw = mapping.inverse_transform(Xt=X2)

I1tgw = minmax(mat2im(xstgw, I1.shape))
I2tgw = minmax(mat2im(xtsgw, I2.shape))

# %%


##############################################################################
# Plot transformed images
# -----------------------

plt.figure(2, figsize=(10, 7))
plt.figure(3, figsize=(14, 7))

plt.subplot(2, 2, 1)
plt.subplot(2, 3, 1)
plt.imshow(I1)
plt.axis('off')
plt.title('Im. 1')

plt.subplot(2, 2, 2)
plt.subplot(2, 3, 4)
plt.imshow(I2)
plt.axis('off')
plt.title('Im. 2')

plt.subplot(2, 2, 3)
plt.subplot(2, 3, 2)
plt.imshow(I1t)
plt.axis('off')
plt.title('Mapping Im. 1')
plt.title('Monge mapping Im. 1')

plt.subplot(2, 2, 4)
plt.subplot(2, 3, 5)
plt.imshow(I2t)
plt.axis('off')
plt.title('Inverse mapping Im. 2')
plt.title('Inverse Monge mapping Im. 2')

plt.subplot(2, 3, 3)
plt.imshow(I1tgw)
plt.axis('off')
plt.title('Gaussian GW mapping Im. 1')

plt.subplot(2, 3, 6)
plt.imshow(I2tgw)
plt.axis('off')
plt.title('Inverse Gaussian GW mapping Im. 2')
49 changes: 49 additions & 0 deletions ot/backend.py
@@ -338,6 +338,15 @@
"""
raise NotImplementedError()

def sign(self, a):
r""" Returns an element-wise indication of the sign of a number.

This function follows the api from :any:`numpy.sign`

See: https://numpy.org/doc/stable/reference/generated/numpy.sign.html
"""
raise NotImplementedError()

def dot(self, a, b):
r"""
Returns the dot product of two tensors.
@@ -858,6 +867,16 @@
"""
raise NotImplementedError()

def eigh(self, a):
r"""
Computes the eigenvalues and eigenvectors of a symmetric tensor.

This function follows the api from :any:`scipy.linalg.eigh`.

See: https://docs.scipy.org/doc/scipy/reference/generated/scipy.linalg.eigh.html
"""
raise NotImplementedError()
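These abstract methods are dispatched to the concrete backend matching the input arrays. A minimal sketch of how the new primitives are reached from user code (assuming only `ot.backend.get_backend`, POT's standard dispatch helper, and NumPy inputs):

```python
# Hedged sketch: backend-agnostic use of the sign/eigh primitives added here.
import numpy as np
from ot.backend import get_backend

C = np.array([[2.0, 1.0],
              [1.0, 2.0]])  # symmetric matrix
nx = get_backend(C)         # resolves to the NumPy backend for this input

w, V = nx.eigh(C)           # eigenvalues/eigenvectors, as np.linalg.eigh
s = nx.sign(w)              # element-wise signs, as np.sign
```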

def kl_div(self, p, q, eps=1e-16):
r"""
Computes the Kullback-Leibler divergence.
@@ -1047,6 +1066,9 @@
def minimum(self, a, b):
return np.minimum(a, b)

def sign(self, a):
return np.sign(a)

def dot(self, a, b):
return np.dot(a, b)

@@ -1253,6 +1275,9 @@
L, V = np.linalg.eigh(a)
return (V * np.sqrt(L)[None, :]) @ V.T

def eigh(self, a):
return np.linalg.eigh(a)

def kl_div(self, p, q, eps=1e-16):
return np.sum(p * np.log(p / q + eps))

@@ -1415,6 +1440,9 @@
def minimum(self, a, b):
return jnp.minimum(a, b)

def sign(self, a):
return jnp.sign(a)

def dot(self, a, b):
return jnp.dot(a, b)

@@ -1631,6 +1659,9 @@
L, V = jnp.linalg.eigh(a)
return (V * jnp.sqrt(L)[None, :]) @ V.T

def eigh(self, a):
return jnp.linalg.eigh(a)

def kl_div(self, p, q, eps=1e-16):
return jnp.sum(p * jnp.log(p / q + eps))

@@ -1829,6 +1860,9 @@
else:
return torch.min(torch.stack(torch.broadcast_tensors(a, b)), axis=0)[0]

def sign(self, a):
return torch.sign(a)

def dot(self, a, b):
return torch.matmul(a, b)

@@ -1922,16 +1956,16 @@
# Since version 1.11.0, interpolation is available
if version.parse(torch.__version__) >= version.parse("1.11.0"):
    if axis is not None:
        return torch.quantile(a, 0.5, interpolation="midpoint", dim=axis)
    else:
        return torch.quantile(a, 0.5, interpolation="midpoint")

# Else, use numpy
warnings.warn("The median is being computed using numpy and the array has been detached "
              "in the Pytorch backend.")
a_ = self.to_numpy(a)
a_median = np.median(a_, axis=axis)
return self.from_numpy(a_median, type_as=a)

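The `interpolation="midpoint"` choice above matters because `torch.median` does not average the two middle values of an even-length input the way `numpy.median` does. A small illustration (assumes torch >= 1.11):

```python
# Why interpolation="midpoint" is used: torch.median returns the lower of
# the two middle values on even-length input, while numpy.median averages.
import torch

a = torch.tensor([1.0, 2.0, 3.0, 4.0])
print(torch.median(a))                                   # tensor(2.)
print(torch.quantile(a, 0.5, interpolation="midpoint"))  # tensor(2.5000)
```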

def std(self, a, axis=None):
if axis is not None:
@@ -2106,6 +2140,9 @@
L, V = torch.linalg.eigh(a)
return (V * torch.sqrt(L)[None, :]) @ V.T

def eigh(self, a):
return torch.linalg.eigh(a)

def kl_div(self, p, q, eps=1e-16):
return torch.sum(p * torch.log(p / q + eps))

@@ -2248,6 +2285,9 @@
def minimum(self, a, b):
return cp.minimum(a, b)

def sign(self, a):
return cp.sign(a)

def abs(self, a):
return cp.abs(a)

@@ -2495,6 +2535,9 @@
L, V = cp.linalg.eigh(a)
return (V * cp.sqrt(L)[None, :]) @ V.T

def eigh(self, a):
return cp.linalg.eigh(a)

def kl_div(self, p, q, eps=1e-16):
return cp.sum(p * cp.log(p / q + eps))

@@ -2539,7 +2582,7 @@

if cp:
# Only register cp backend if it is installed
register_backend(CupyBackend())


class TensorflowBackend(Backend):
@@ -2642,6 +2685,9 @@
def minimum(self, a, b):
return tnp.minimum(a, b)

def sign(self, a):
return tnp.sign(a)

def dot(self, a, b):
if len(b.shape) == 1:
if len(a.shape) == 1:
@@ -2902,6 +2948,9 @@
L, V = tf.linalg.eigh(a)
return (V * tf.sqrt(L)[None, :]) @ V.T

def eigh(self, a):
return tf.linalg.eigh(a)

def kl_div(self, p, q, eps=1e-16):
return tnp.sum(p * tnp.log(p / q + eps))
