diff --git a/tests/test_conditional_flow_matcher.py b/tests/test_conditional_flow_matcher.py index b080470..9c8767f 100644 --- a/tests/test_conditional_flow_matcher.py +++ b/tests/test_conditional_flow_matcher.py @@ -94,7 +94,8 @@ def sample_plan(method, x0, x1, sigma): # Test both integer and floating sigma @pytest.mark.parametrize("sigma", [0.0, 5e-4, 0.5, 1.5, 0, 1]) @pytest.mark.parametrize("shape", [[1], [2], [1, 2], [3, 4, 5]]) -def test_fm(method, sigma, shape): +@pytest.mark.parametrize("test_eps", [False, True]) +def test_fm(method, sigma, shape, test_eps): batch_size = TEST_BATCH_SIZE if method in SIGMA_CONDITION.keys() and SIGMA_CONDITION[method](sigma): @@ -106,7 +107,12 @@ def test_fm(method, sigma, shape): x0, x1 = random_samples(shape, batch_size=batch_size) torch.manual_seed(TEST_SEED) np.random.seed(TEST_SEED) - t, xt, ut, eps = FM.sample_location_and_conditional_flow(x0, x1, return_noise=True) + eps = None + if test_eps: + eps = torch.randn_like(x0) + t, xt, ut, ret_eps = FM.sample_location_and_conditional_flow(x0, x1, return_noise=True, eps=eps) + if test_eps: + assert torch.allclose(ret_eps, eps) _ = FM.compute_lambda(t) if method in ["sb_cfm", "exact_ot_cfm"]: @@ -115,13 +121,14 @@ def test_fm(method, sigma, shape): x0, x1 = sample_plan(method, x0, x1, sigma) torch.manual_seed(TEST_SEED) + if test_eps: + # compute to get same t seed + eps = torch.randn_like(x0) t_given_init = torch.rand(batch_size) t_given = t_given_init.reshape(-1, *([1] * (x0.dim() - 1))) sigma_pad = pad_t_like_x(sigma, x0) - epsilon = torch.randn_like(x0) - computed_xt, computed_ut = compute_xt_ut(method, x0, x1, t_given, sigma_pad, epsilon) + computed_xt, computed_ut = compute_xt_ut(method, x0, x1, t_given, sigma_pad, ret_eps) assert torch.all(ut.eq(computed_ut)) - assert torch.all(xt.eq(computed_xt)) - assert torch.all(eps.eq(epsilon)) + assert torch.allclose(xt, computed_xt) assert any(t_given_init == t) diff --git a/torchcfm/conditional_flow_matching.py b/torchcfm/conditional_flow_matching.py index 9605081..1489286 100644 --- a/torchcfm/conditional_flow_matching.py +++ b/torchcfm/conditional_flow_matching.py @@ -152,10 +152,7 @@ def compute_conditional_flow(self, x0, x1, t, xt): del t, xt return x1 - x0 - def sample_noise_like(self, x): - return torch.randn_like(x) - - def sample_location_and_conditional_flow(self, x0, x1, t=None, eps=None, return_noise=False): + def sample_location_and_conditional_flow(self, x0, x1, t=None, return_noise=False, eps=None): """ Compute the sample xt (drawn from N(t * x1 + (1 - t) * x0, sigma)) and the conditional vector field ut(x1|x0) = x1 - x0, see Eq.(15) [1]. @@ -169,8 +166,10 @@ def sample_location_and_conditional_flow(self, x0, x1, t=None, eps=None, return_ (optionally) t : Tensor, shape (bs) represents the time levels if None, drawn from uniform [0,1] - return_noise : bool + (optionally) return_noise : bool return the noise sample epsilon + (optionally) eps: Tensor, shape (bs, *dim) + use a fixed noise vector epsilon Returns @@ -190,7 +189,7 @@ def sample_location_and_conditional_flow(self, x0, x1, t=None, eps=None, return_ assert len(t) == x0.shape[0], "t has to have batch size dimension" if eps is None: - eps = self.sample_noise_like(x0) + eps = torch.randn_like(x0) xt = self.sample_xt(x0, x1, t, eps) ut = self.compute_conditional_flow(x0, x1, t, xt) if return_noise: @@ -235,7 +234,7 @@ def __init__(self, sigma: Union[float, int] = 0.0): super().__init__(sigma) self.ot_sampler = OTPlanSampler(method="exact") - def sample_location_and_conditional_flow(self, x0, x1, t=None, return_noise=False): + def sample_location_and_conditional_flow(self, x0, x1, t=None, return_noise=False, eps=None): r""" Compute the sample xt (drawn from N(t * x1 + (1 - t) * x0, sigma)) and the conditional vector field ut(x1|x0) = x1 - x0, see Eq.(15) [1] @@ -250,8 +249,10 @@ def sample_location_and_conditional_flow(self, x0, x1, t=None, return_noise=Fals (optionally) t : Tensor, shape (bs) represents the time levels if None, drawn from uniform [0,1] - return_noise : bool + (optionally) return_noise : bool return the noise sample epsilon + (optionally) eps: Tensor, shape (bs, *dim) + use a fixed noise vector epsilon Returns ------- @@ -266,10 +267,10 @@ def sample_location_and_conditional_flow(self, x0, x1, t=None, return_noise=Fals [1] Improving and Generalizing Flow-Based Generative Models with minibatch optimal transport, Preprint, Tong et al. """ x0, x1 = self.ot_sampler.sample_plan(x0, x1) - return super().sample_location_and_conditional_flow(x0, x1, t, return_noise) + return super().sample_location_and_conditional_flow(x0, x1, t, return_noise, eps) def guided_sample_location_and_conditional_flow( - self, x0, x1, y0=None, y1=None, t=None, return_noise=False + self, x0, x1, y0=None, y1=None, t=None, return_noise=False, eps=None ): r""" Compute the sample xt (drawn from N(t * x1 + (1 - t) * x0, sigma)) @@ -289,8 +290,10 @@ def guided_sample_location_and_conditional_flow( (optionally) t : Tensor, shape (bs) represents the time levels if None, drawn from uniform [0,1] - return_noise : bool + (optionally) return_noise : bool return the noise sample epsilon + (optionally) eps: Tensor, shape (bs, *dim) + use a fixed noise vector epsilon Returns ------- @@ -306,10 +309,10 @@ def guided_sample_location_and_conditional_flow( """ x0, x1, y0, y1 = self.ot_sampler.sample_plan_with_labels(x0, x1, y0, y1) if return_noise: - t, xt, ut, eps = super().sample_location_and_conditional_flow(x0, x1, t, return_noise) + t, xt, ut, eps = super().sample_location_and_conditional_flow(x0, x1, t, return_noise, eps) return t, xt, ut, y0, y1, eps else: - t, xt, ut = super().sample_location_and_conditional_flow(x0, x1, t, return_noise) + t, xt, ut = super().sample_location_and_conditional_flow(x0, x1, t, return_noise, eps) return t, xt, ut, y0, y1 @@ -469,7 +472,7 @@ def compute_conditional_flow(self, x0, x1, t, xt): ut = sigma_t_prime_over_sigma_t * (xt - mu_t) + x1 - x0 return ut - def sample_location_and_conditional_flow(self, x0, x1, t=None, return_noise=False): + def sample_location_and_conditional_flow(self, x0, x1, t=None, return_noise=False, eps=None): """ Compute the sample xt (drawn from N(t * x1 + (1 - t) * x0, sqrt(t * (1 - t))*sigma^2 )) and the conditional vector field ut(x1|x0) = (1 - 2 * t) / (2 * t * (1 - t)) * (xt - mu_t) + x1 - x0, @@ -484,8 +487,10 @@ def sample_location_and_conditional_flow(self, x0, x1, t=None, return_noise=Fals (optionally) t : Tensor, shape (bs) represents the time levels if None, drawn from uniform [0,1] - return_noise: bool + (optionally) return_noise: bool return the noise sample epsilon + (optionally) eps: Tensor, shape (bs, *dim) + use a fixed noise vector epsilon Returns @@ -501,10 +506,10 @@ def sample_location_and_conditional_flow(self, x0, x1, t=None, return_noise=Fals [1] Improving and Generalizing Flow-Based Generative Models with minibatch optimal transport, Preprint, Tong et al. """ x0, x1 = self.ot_sampler.sample_plan(x0, x1) - return super().sample_location_and_conditional_flow(x0, x1, t, return_noise) + return super().sample_location_and_conditional_flow(x0, x1, t, return_noise, eps) def guided_sample_location_and_conditional_flow( - self, x0, x1, y0=None, y1=None, t=None, return_noise=False + self, x0, x1, y0=None, y1=None, t=None, return_noise=False, eps=None ): r""" Compute the sample xt (drawn from N(t * x1 + (1 - t) * x0, sigma)) @@ -524,8 +529,10 @@ def guided_sample_location_and_conditional_flow( (optionally) t : Tensor, shape (bs) represents the time levels if None, drawn from uniform [0,1] - return_noise : bool + (optionally) return_noise : bool return the noise sample epsilon + (optionally) eps: Tensor, shape (bs, *dim) + use a fixed noise vector epsilon Returns ------- @@ -541,10 +548,10 @@ def guided_sample_location_and_conditional_flow( """ x0, x1, y0, y1 = self.ot_sampler.sample_plan_with_labels(x0, x1, y0, y1) if return_noise: - t, xt, ut, eps = super().sample_location_and_conditional_flow(x0, x1, t, return_noise) + t, xt, ut, eps = super().sample_location_and_conditional_flow(x0, x1, t, return_noise, eps) return t, xt, ut, y0, y1, eps else: - t, xt, ut = super().sample_location_and_conditional_flow(x0, x1, t, return_noise) + t, xt, ut = super().sample_location_and_conditional_flow(x0, x1, t, return_noise, eps) return t, xt, ut, y0, y1