diff --git a/tests/test_conditional_flow_matcher.py b/tests/test_conditional_flow_matcher.py index 9c8767f..3b0d385 100644 --- a/tests/test_conditional_flow_matcher.py +++ b/tests/test_conditional_flow_matcher.py @@ -110,7 +110,9 @@ def test_fm(method, sigma, shape, test_eps): eps = None if test_eps: eps = torch.randn_like(x0) - t, xt, ut, ret_eps = FM.sample_location_and_conditional_flow(x0, x1, return_noise=True, eps=eps) + t, xt, ut, ret_eps = FM.sample_location_and_conditional_flow( + x0, x1, return_noise=True, eps=eps + ) if test_eps: assert torch.allclose(ret_eps, eps) _ = FM.compute_lambda(t) diff --git a/torchcfm/conditional_flow_matching.py b/torchcfm/conditional_flow_matching.py index 1489286..a367440 100644 --- a/torchcfm/conditional_flow_matching.py +++ b/torchcfm/conditional_flow_matching.py @@ -309,7 +309,9 @@ def guided_sample_location_and_conditional_flow( """ x0, x1, y0, y1 = self.ot_sampler.sample_plan_with_labels(x0, x1, y0, y1) if return_noise: - t, xt, ut, eps = super().sample_location_and_conditional_flow(x0, x1, t, return_noise, eps) + t, xt, ut, eps = super().sample_location_and_conditional_flow( + x0, x1, t, return_noise, eps + ) return t, xt, ut, y0, y1, eps else: t, xt, ut = super().sample_location_and_conditional_flow(x0, x1, t, return_noise, eps) @@ -548,7 +550,9 @@ def guided_sample_location_and_conditional_flow( """ x0, x1, y0, y1 = self.ot_sampler.sample_plan_with_labels(x0, x1, y0, y1) if return_noise: - t, xt, ut, eps = super().sample_location_and_conditional_flow(x0, x1, t, return_noise, eps) + t, xt, ut, eps = super().sample_location_and_conditional_flow( + x0, x1, t, return_noise, eps + ) return t, xt, ut, y0, y1, eps else: t, xt, ut = super().sample_location_and_conditional_flow(x0, x1, t, return_noise, eps)