Skip to content

Commit

Permalink
test infinite reg sinkhorn unbalanced
Browse files Browse the repository at this point in the history
  • Loading branch information
clbonet committed Sep 20, 2024
1 parent 2acbdbb commit 7f9acf7
Show file tree
Hide file tree
Showing 3 changed files with 46 additions and 4 deletions.
1 change: 1 addition & 0 deletions RELEASES.md
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
- `nx.sqrtm` is now broadcastable (takes ..., d, d) inputs (PR #649)
- restructure `ot.unbalanced` module (PR #658)
- add `ot.unbalanced.lbfgsb_unbalanced2` and add flexible reference measure `c` in all unbalanced solvers (PR #658)
- Added `sinkhorn_unbalanced_translation_invariant` in `ot.unbalanced.sinkhorn_unbalanced` (PR #676)

#### Closed issues
- Fixed `ot.gaussian` ignoring weights when computing means (PR #649, Issue #648)
Expand Down
8 changes: 4 additions & 4 deletions ot/unbalanced/_sinkhorn.py
Original file line number Diff line number Diff line change
Expand Up @@ -372,10 +372,10 @@ def sinkhorn_unbalanced2(a, b, M, reg, reg_m, method='sinkhorn',
log=log, **kwargs)

elif method.lower() == 'sinkhorn_translation_invariant':
return sinkhorn_stabilized_unbalanced(a, b, M, reg, reg_m, reg_type, c,
warmstart, numItermax=numItermax,
stopThr=stopThr, verbose=verbose,
log=log, **kwargs)
return sinkhorn_unbalanced_translation_invariant(a, b, M, reg, reg_m, reg_type, c,
warmstart, numItermax=numItermax,
stopThr=stopThr, verbose=verbose,
log=log, **kwargs)

elif method.lower() in ['sinkhorn_reg_scaling']:
warnings.warn('Method not implemented yet. Using classic Sinkhorn-Knopp')
Expand Down
41 changes: 41 additions & 0 deletions test/unbalanced/test_sinkhorn.py
Original file line number Diff line number Diff line change
Expand Up @@ -288,6 +288,47 @@ def test_unbalanced_relaxation_parameters(nx, method, reg_m):
nx.to_numpy(loss), nx.to_numpy(loss_opt), atol=1e-05)


@pytest.mark.parametrize("method,reg_m1, reg_m2", itertools.product(["sinkhorn", "sinkhorn_stabilized", "sinkhorn_reg_scaling", "sinkhorn_translation_invariant"], [1, float("inf")], [1, float("inf")]))
def test_unbalanced_relaxation_parameters_pair(nx, method, reg_m1, reg_m2):
# test generalized sinkhorn for unbalanced OT
n = 100
rng = np.random.RandomState(50)

x = rng.randn(n, 2)
a = ot.utils.unif(n)

# make dists unbalanced
b = rng.rand(n, 2)

M = ot.dist(x, x)
epsilon = 1.

a, b, M = nx.from_numpy(a, b, M)

# options for reg_m
full_list_reg_m = [reg_m1, reg_m2]
full_tuple_reg_m = (reg_m1, reg_m2)
list_options = [full_tuple_reg_m, full_list_reg_m]

loss, log = ot.unbalanced.sinkhorn_unbalanced(
a, b, M, reg=epsilon, reg_m=(reg_m1, reg_m2),
method=method, log=True, verbose=True
)

for opt in list_options:
loss_opt, log_opt = ot.unbalanced.sinkhorn_unbalanced(
a, b, M, reg=epsilon, reg_m=opt,
method=method, log=True, verbose=True
)

np.testing.assert_allclose(
nx.to_numpy(log["logu"]), nx.to_numpy(log_opt["logu"]), atol=1e-05)
np.testing.assert_allclose(
nx.to_numpy(log["logv"]), nx.to_numpy(log_opt["logv"]), atol=1e-05)
np.testing.assert_allclose(
nx.to_numpy(loss), nx.to_numpy(loss_opt), atol=1e-05)


@pytest.mark.parametrize("method", ["sinkhorn", "sinkhorn_stabilized", "sinkhorn_reg_scaling", "sinkhorn_translation_invariant"])
def test_unbalanced_multiple_inputs(nx, method):
# test generalized sinkhorn for unbalanced OT
Expand Down

0 comments on commit 7f9acf7

Please sign in to comment.