Skip to content

Commit

Permalink
add test sample_plan
Browse files Browse the repository at this point in the history
  • Loading branch information
kilianFatras committed Nov 21, 2023
1 parent 4a4032e commit a040bb1
Showing 1 changed file with 30 additions and 7 deletions.
37 changes: 30 additions & 7 deletions tests/test_optimal_transport.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,38 +5,61 @@
import math

import numpy as np
import ot
import pytest
import torch
import ot

from torchcfm.optimal_transport import OTPlanSampler

ot_sampler = OTPlanSampler(method="exact")


def test_sample_map(batch_size=128):
# Build sparse random OT map
# Build sparse random OT map
map = np.eye(batch_size)
rng = np.random.default_rng()
permuted_map = rng.permutation(map, axis=1)

# Sample elements from the OT plan
# All elements should be sampled only once
indices = ot_sampler.sample_map(permuted_map, batch_size=batch_size, replace=False)

# Reconstruct the coupling from the sampled elements
reconstructed_a = np.zeros((batch_size, batch_size))
reconstructed_map = np.zeros((batch_size, batch_size))
for i in range(batch_size):
reconstructed_a[indices[0][i], indices[1][i]] = 1
assert np.array_equal(reconstructed_a, permuted_map)
reconstructed_map[indices[0][i], indices[1][i]] = 1
assert np.array_equal(reconstructed_map, permuted_map)


def test_get_map(batch_size=128):
x0 = torch.randn(batch_size, 2, 2, 2)
x1 = torch.randn(batch_size, 2, 2, 2)

M = torch.cdist(x0.reshape(x0.shape[0], -1), x1.reshape(x1.shape[0], -1))**2
M = torch.cdist(x0.reshape(x0.shape[0], -1), x1.reshape(x1.shape[0], -1)) ** 2
pot_pi = ot.emd(ot.unif(x0.shape[0]), ot.unif(x1.shape[0]), M.numpy())

pi = ot_sampler.get_map(x0, x1)

assert np.array_equal(pi, pot_pi)


def test_sample_plan(batch_size=128, seed=1980):
torch.manual_seed(seed)
np.random.seed(seed)
x0 = torch.randn(batch_size, 2, 2, 2)
x1 = torch.randn(batch_size, 2, 2, 2)

pi = ot_sampler.get_map(x0, x1)
indices_i, indices_j = ot_sampler.sample_map(pi, batch_size=batch_size, replace=True)
new_x0, new_x1 = x0[indices_i], x1[indices_j]

torch.manual_seed(seed)
np.random.seed(seed)

sampled_x0, sampled_x1 = ot_sampler.sample_plan(x0, x1, replace=True)

assert torch.equal(new_x0, sampled_x0)
assert torch.equal(new_x1, sampled_x1)


test_sample_plan()

0 comments on commit a040bb1

Please sign in to comment.