Skip to content

Commit

Permalink
pep 8
Browse files Browse the repository at this point in the history
  • Loading branch information
kilianFatras committed Nov 23, 2023
1 parent 2672db1 commit d700429
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 8 deletions.
1 change: 1 addition & 0 deletions tests/test_models.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import pytest

from torchcfm.models import MLP
from torchcfm.models.unet import UNetModel

Expand Down
18 changes: 10 additions & 8 deletions tests/test_optimal_transport.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,17 +71,19 @@ def test_wasserstein(batch_size=128, seed=1980):
M = torch.cdist(x0.reshape(x0.shape[0], -1), x1.reshape(x1.shape[0], -1))
pot_W22 = ot.emd2(ot.unif(x0.shape[0]), ot.unif(x1.shape[0]), (M**2).numpy())
pot_W2 = np.sqrt(pot_W22)
W2 = wasserstein(x0, x1, 'exact')
W2 = wasserstein(x0, x1, "exact")

pot_W1 = ot.emd2(ot.unif(x0.shape[0]), ot.unif(x1.shape[0]), M.numpy())
W1 = wasserstein(x0, x1, 'exact', power=1)
W1 = wasserstein(x0, x1, "exact", power=1)

pot_eot = ot.sinkhorn2(ot.unif(x0.shape[0]), ot.unif(x1.shape[0]), M.numpy(), reg=0.01, numItermax=int(1e7))
eot = wasserstein(x0, x1, 'sinkhorn', reg=0.01, power=1)
pot_eot = ot.sinkhorn2(
ot.unif(x0.shape[0]), ot.unif(x1.shape[0]), M.numpy(), reg=0.01, numItermax=int(1e7)
)
eot = wasserstein(x0, x1, "sinkhorn", reg=0.01, power=1)

with pytest.raises(ValueError) as excinfo:
eot = wasserstein(x0, x1, 'noname', reg=0.01, power=1)
eot = wasserstein(x0, x1, "noname", reg=0.01, power=1)

assert pot_W2==W2
assert pot_W1==W1
assert pot_eot==eot
assert pot_W2 == W2
assert pot_W1 == W1
assert pot_eot == eot

0 comments on commit d700429

Please sign in to comment.