Skip to content

Commit

Permalink
Plumb eps through all methods. Add test
Browse files Browse the repository at this point in the history
  • Loading branch information
atong01 committed Jul 10, 2024
1 parent c710150 commit 7c4984a
Show file tree
Hide file tree
Showing 2 changed files with 40 additions and 26 deletions.
19 changes: 13 additions & 6 deletions tests/test_conditional_flow_matcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,8 @@ def sample_plan(method, x0, x1, sigma):
# Test both integer and floating sigma
@pytest.mark.parametrize("sigma", [0.0, 5e-4, 0.5, 1.5, 0, 1])
@pytest.mark.parametrize("shape", [[1], [2], [1, 2], [3, 4, 5]])
def test_fm(method, sigma, shape):
@pytest.mark.parametrize("test_eps", [False, True])
def test_fm(method, sigma, shape, test_eps):
batch_size = TEST_BATCH_SIZE

if method in SIGMA_CONDITION.keys() and SIGMA_CONDITION[method](sigma):
Expand All @@ -106,7 +107,12 @@ def test_fm(method, sigma, shape):
x0, x1 = random_samples(shape, batch_size=batch_size)
torch.manual_seed(TEST_SEED)
np.random.seed(TEST_SEED)
t, xt, ut, eps = FM.sample_location_and_conditional_flow(x0, x1, return_noise=True)
eps = None
if test_eps:
eps = torch.randn_like(x0)
t, xt, ut, ret_eps = FM.sample_location_and_conditional_flow(x0, x1, return_noise=True, eps=eps)
if test_eps:
assert torch.allclose(ret_eps, eps)
_ = FM.compute_lambda(t)

if method in ["sb_cfm", "exact_ot_cfm"]:
Expand All @@ -115,13 +121,14 @@ def test_fm(method, sigma, shape):
x0, x1 = sample_plan(method, x0, x1, sigma)

torch.manual_seed(TEST_SEED)
if test_eps:
# compute to get same t seed
eps = torch.randn_like(x0)
t_given_init = torch.rand(batch_size)
t_given = t_given_init.reshape(-1, *([1] * (x0.dim() - 1)))
sigma_pad = pad_t_like_x(sigma, x0)
epsilon = torch.randn_like(x0)
computed_xt, computed_ut = compute_xt_ut(method, x0, x1, t_given, sigma_pad, epsilon)
computed_xt, computed_ut = compute_xt_ut(method, x0, x1, t_given, sigma_pad, ret_eps)

assert torch.all(ut.eq(computed_ut))
assert torch.all(xt.eq(computed_xt))
assert torch.all(eps.eq(epsilon))
assert torch.allclose(xt, computed_xt)
assert any(t_given_init == t)
47 changes: 27 additions & 20 deletions torchcfm/conditional_flow_matching.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,10 +152,7 @@ def compute_conditional_flow(self, x0, x1, t, xt):
del t, xt
return x1 - x0

def sample_noise_like(self, x):
return torch.randn_like(x)

def sample_location_and_conditional_flow(self, x0, x1, t=None, eps=None, return_noise=False):
def sample_location_and_conditional_flow(self, x0, x1, t=None, return_noise=False, eps=None):
"""
Compute the sample xt (drawn from N(t * x1 + (1 - t) * x0, sigma))
and the conditional vector field ut(x1|x0) = x1 - x0, see Eq.(15) [1].
Expand All @@ -169,8 +166,10 @@ def sample_location_and_conditional_flow(self, x0, x1, t=None, eps=None, return_
(optionally) t : Tensor, shape (bs)
represents the time levels
if None, drawn from uniform [0,1]
return_noise : bool
(optionally) return_noise : bool
return the noise sample epsilon
(optionally) eps: Tensor, shape (bs, *dim)
use a fixed noise vector epsilon
Returns
Expand All @@ -190,7 +189,7 @@ def sample_location_and_conditional_flow(self, x0, x1, t=None, eps=None, return_
assert len(t) == x0.shape[0], "t has to have batch size dimension"

if eps is None:
eps = self.sample_noise_like(x0)
eps = torch.randn_like(x0)
xt = self.sample_xt(x0, x1, t, eps)
ut = self.compute_conditional_flow(x0, x1, t, xt)
if return_noise:
Expand Down Expand Up @@ -235,7 +234,7 @@ def __init__(self, sigma: Union[float, int] = 0.0):
super().__init__(sigma)
self.ot_sampler = OTPlanSampler(method="exact")

def sample_location_and_conditional_flow(self, x0, x1, t=None, return_noise=False):
def sample_location_and_conditional_flow(self, x0, x1, t=None, return_noise=False, eps=None):
r"""
Compute the sample xt (drawn from N(t * x1 + (1 - t) * x0, sigma))
and the conditional vector field ut(x1|x0) = x1 - x0, see Eq.(15) [1]
Expand All @@ -250,8 +249,10 @@ def sample_location_and_conditional_flow(self, x0, x1, t=None, return_noise=Fals
(optionally) t : Tensor, shape (bs)
represents the time levels
if None, drawn from uniform [0,1]
return_noise : bool
(optionally) return_noise : bool
return the noise sample epsilon
(optionally) eps: Tensor, shape (bs, *dim)
use a fixed noise vector epsilon
Returns
-------
Expand All @@ -266,10 +267,10 @@ def sample_location_and_conditional_flow(self, x0, x1, t=None, return_noise=Fals
[1] Improving and Generalizing Flow-Based Generative Models with minibatch optimal transport, Preprint, Tong et al.
"""
x0, x1 = self.ot_sampler.sample_plan(x0, x1)
return super().sample_location_and_conditional_flow(x0, x1, t, return_noise)
return super().sample_location_and_conditional_flow(x0, x1, t, return_noise, eps)

def guided_sample_location_and_conditional_flow(
self, x0, x1, y0=None, y1=None, t=None, return_noise=False
self, x0, x1, y0=None, y1=None, t=None, return_noise=False, eps=None
):
r"""
Compute the sample xt (drawn from N(t * x1 + (1 - t) * x0, sigma))
Expand All @@ -289,8 +290,10 @@ def guided_sample_location_and_conditional_flow(
(optionally) t : Tensor, shape (bs)
represents the time levels
if None, drawn from uniform [0,1]
return_noise : bool
(optionally) return_noise : bool
return the noise sample epsilon
(optionally) eps: Tensor, shape (bs, *dim)
use a fixed noise vector epsilon
Returns
-------
Expand All @@ -306,10 +309,10 @@ def guided_sample_location_and_conditional_flow(
"""
x0, x1, y0, y1 = self.ot_sampler.sample_plan_with_labels(x0, x1, y0, y1)
if return_noise:
t, xt, ut, eps = super().sample_location_and_conditional_flow(x0, x1, t, return_noise)
t, xt, ut, eps = super().sample_location_and_conditional_flow(x0, x1, t, return_noise, eps)
return t, xt, ut, y0, y1, eps
else:
t, xt, ut = super().sample_location_and_conditional_flow(x0, x1, t, return_noise)
t, xt, ut = super().sample_location_and_conditional_flow(x0, x1, t, return_noise, eps)
return t, xt, ut, y0, y1


Expand Down Expand Up @@ -469,7 +472,7 @@ def compute_conditional_flow(self, x0, x1, t, xt):
ut = sigma_t_prime_over_sigma_t * (xt - mu_t) + x1 - x0
return ut

def sample_location_and_conditional_flow(self, x0, x1, t=None, return_noise=False):
def sample_location_and_conditional_flow(self, x0, x1, t=None, return_noise=False, eps=None):
"""
Compute the sample xt (drawn from N(t * x1 + (1 - t) * x0, sqrt(t * (1 - t))*sigma^2 ))
and the conditional vector field ut(x1|x0) = (1 - 2 * t) / (2 * t * (1 - t)) * (xt - mu_t) + x1 - x0,
Expand All @@ -484,8 +487,10 @@ def sample_location_and_conditional_flow(self, x0, x1, t=None, return_noise=Fals
(optionally) t : Tensor, shape (bs)
represents the time levels
if None, drawn from uniform [0,1]
return_noise: bool
(optionally) return_noise: bool
return the noise sample epsilon
(optionally) eps: Tensor, shape (bs, *dim)
use a fixed noise vector epsilon
Returns
Expand All @@ -501,10 +506,10 @@ def sample_location_and_conditional_flow(self, x0, x1, t=None, return_noise=Fals
[1] Improving and Generalizing Flow-Based Generative Models with minibatch optimal transport, Preprint, Tong et al.
"""
x0, x1 = self.ot_sampler.sample_plan(x0, x1)
return super().sample_location_and_conditional_flow(x0, x1, t, return_noise)
return super().sample_location_and_conditional_flow(x0, x1, t, return_noise, eps)

def guided_sample_location_and_conditional_flow(
self, x0, x1, y0=None, y1=None, t=None, return_noise=False
self, x0, x1, y0=None, y1=None, t=None, return_noise=False, eps=None
):
r"""
Compute the sample xt (drawn from N(t * x1 + (1 - t) * x0, sigma))
Expand All @@ -524,8 +529,10 @@ def guided_sample_location_and_conditional_flow(
(optionally) t : Tensor, shape (bs)
represents the time levels
if None, drawn from uniform [0,1]
return_noise : bool
(optionally) return_noise : bool
return the noise sample epsilon
(optionally) eps: Tensor, shape (bs, *dim)
use a fixed noise vector epsilon
Returns
-------
Expand All @@ -541,10 +548,10 @@ def guided_sample_location_and_conditional_flow(
"""
x0, x1, y0, y1 = self.ot_sampler.sample_plan_with_labels(x0, x1, y0, y1)
if return_noise:
t, xt, ut, eps = super().sample_location_and_conditional_flow(x0, x1, t, return_noise)
t, xt, ut, eps = super().sample_location_and_conditional_flow(x0, x1, t, return_noise, eps)
return t, xt, ut, y0, y1, eps
else:
t, xt, ut = super().sample_location_and_conditional_flow(x0, x1, t, return_noise)
t, xt, ut = super().sample_location_and_conditional_flow(x0, x1, t, return_noise, eps)
return t, xt, ut, y0, y1


Expand Down

0 comments on commit 7c4984a

Please sign in to comment.