Skip to content

Commit

Permalink
add initial OT tests
Browse files Browse the repository at this point in the history
  • Loading branch information
kilianFatras committed Nov 20, 2023
1 parent 6b3e2b6 commit 4a4032e
Showing 1 changed file with 42 additions and 0 deletions.
42 changes: 42 additions & 0 deletions tests/test_optimal_transport.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
"""Tests for Conditional Flow Matcher classers."""

# Author: Kilian Fatras <[email protected]>

import math

import numpy as np
import pytest
import torch
import ot

from torchcfm.optimal_transport import OTPlanSampler

ot_sampler = OTPlanSampler(method="exact")

def test_sample_map(batch_size=128):
# Build sparse random OT map
map = np.eye(batch_size)
rng = np.random.default_rng()
permuted_map = rng.permutation(map, axis=1)

# Sample elements from the OT plan
# All elements should be sampled only once
indices = ot_sampler.sample_map(permuted_map, batch_size=batch_size, replace=False)

# Reconstruct the coupling from the sampled elements
reconstructed_a = np.zeros((batch_size, batch_size))
for i in range(batch_size):
reconstructed_a[indices[0][i], indices[1][i]] = 1
assert np.array_equal(reconstructed_a, permuted_map)


def test_get_map(batch_size=128):
x0 = torch.randn(batch_size, 2, 2, 2)
x1 = torch.randn(batch_size, 2, 2, 2)

M = torch.cdist(x0.reshape(x0.shape[0], -1), x1.reshape(x1.shape[0], -1))**2
pot_pi = ot.emd(ot.unif(x0.shape[0]), ot.unif(x1.shape[0]), M.numpy())

pi = ot_sampler.get_map(x0, x1)

assert np.array_equal(pi, pot_pi)

0 comments on commit 4a4032e

Please sign in to comment.