From 7f9acf7a90928dd3df298f1d88bec9fcdf5bd771 Mon Sep 17 00:00:00 2001 From: Clement Date: Fri, 20 Sep 2024 11:11:16 +0200 Subject: [PATCH] test infinite reg sinkhorn unbalanced --- RELEASES.md | 1 + ot/unbalanced/_sinkhorn.py | 8 +++---- test/unbalanced/test_sinkhorn.py | 41 ++++++++++++++++++++++++++++++++ 3 files changed, 46 insertions(+), 4 deletions(-) diff --git a/RELEASES.md b/RELEASES.md index 277af7847..452982a6e 100644 --- a/RELEASES.md +++ b/RELEASES.md @@ -12,6 +12,7 @@ - `nx.sqrtm` is now broadcastable (takes ..., d, d) inputs (PR #649) - restructure `ot.unbalanced` module (PR #658) - add `ot.unbalanced.lbfgsb_unbalanced2` and add flexible reference measure `c` in all unbalanced solvers (PR #658) +- Added `sinkhorn_unbalanced_translation_invariant` in `ot.unbalanced.sinkhorn_unbalanced` (PR #676) #### Closed issues - Fixed `ot.gaussian` ignoring weights when computing means (PR #649, Issue #648) diff --git a/ot/unbalanced/_sinkhorn.py b/ot/unbalanced/_sinkhorn.py index 1f926b946..8b9218a27 100644 --- a/ot/unbalanced/_sinkhorn.py +++ b/ot/unbalanced/_sinkhorn.py @@ -372,10 +372,10 @@ def sinkhorn_unbalanced2(a, b, M, reg, reg_m, method='sinkhorn', log=log, **kwargs) elif method.lower() == 'sinkhorn_translation_invariant': - return sinkhorn_stabilized_unbalanced(a, b, M, reg, reg_m, reg_type, c, - warmstart, numItermax=numItermax, - stopThr=stopThr, verbose=verbose, - log=log, **kwargs) + return sinkhorn_unbalanced_translation_invariant(a, b, M, reg, reg_m, reg_type, c, + warmstart, numItermax=numItermax, + stopThr=stopThr, verbose=verbose, + log=log, **kwargs) elif method.lower() in ['sinkhorn_reg_scaling']: warnings.warn('Method not implemented yet. Using classic Sinkhorn-Knopp') diff --git a/test/unbalanced/test_sinkhorn.py b/test/unbalanced/test_sinkhorn.py index 7daca6414..6075f022b 100644 --- a/test/unbalanced/test_sinkhorn.py +++ b/test/unbalanced/test_sinkhorn.py @@ -288,6 +288,47 @@ def test_unbalanced_relaxation_parameters(nx, method, reg_m): nx.to_numpy(loss), nx.to_numpy(loss_opt), atol=1e-05) +@pytest.mark.parametrize("method,reg_m1, reg_m2", itertools.product(["sinkhorn", "sinkhorn_stabilized", "sinkhorn_reg_scaling", "sinkhorn_translation_invariant"], [1, float("inf")], [1, float("inf")])) +def test_unbalanced_relaxation_parameters_pair(nx, method, reg_m1, reg_m2): + # test generalized sinkhorn for unbalanced OT + n = 100 + rng = np.random.RandomState(50) + + x = rng.randn(n, 2) + a = ot.utils.unif(n) + + # make dists unbalanced + b = rng.rand(n, 2) + + M = ot.dist(x, x) + epsilon = 1. + + a, b, M = nx.from_numpy(a, b, M) + + # options for reg_m + full_list_reg_m = [reg_m1, reg_m2] + full_tuple_reg_m = (reg_m1, reg_m2) + list_options = [full_tuple_reg_m, full_list_reg_m] + + loss, log = ot.unbalanced.sinkhorn_unbalanced( + a, b, M, reg=epsilon, reg_m=(reg_m1, reg_m2), + method=method, log=True, verbose=True + ) + + for opt in list_options: + loss_opt, log_opt = ot.unbalanced.sinkhorn_unbalanced( + a, b, M, reg=epsilon, reg_m=opt, + method=method, log=True, verbose=True + ) + + np.testing.assert_allclose( + nx.to_numpy(log["logu"]), nx.to_numpy(log_opt["logu"]), atol=1e-05) + np.testing.assert_allclose( + nx.to_numpy(log["logv"]), nx.to_numpy(log_opt["logv"]), atol=1e-05) + np.testing.assert_allclose( + nx.to_numpy(loss), nx.to_numpy(loss_opt), atol=1e-05) + + @pytest.mark.parametrize("method", ["sinkhorn", "sinkhorn_stabilized", "sinkhorn_reg_scaling", "sinkhorn_translation_invariant"]) def test_unbalanced_multiple_inputs(nx, method): # test generalized sinkhorn for unbalanced OT