Skip to content

Commit

Permalink
Reformat with black
Browse files Browse the repository at this point in the history
  • Loading branch information
atong01 committed Jul 11, 2024
1 parent 0b43559 commit cdbbef6
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 3 deletions.
4 changes: 3 additions & 1 deletion tests/test_conditional_flow_matcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,9 @@ def test_fm(method, sigma, shape, test_eps):
eps = None
if test_eps:
eps = torch.randn_like(x0)
t, xt, ut, ret_eps = FM.sample_location_and_conditional_flow(x0, x1, return_noise=True, eps=eps)
t, xt, ut, ret_eps = FM.sample_location_and_conditional_flow(
x0, x1, return_noise=True, eps=eps
)
if test_eps:
assert torch.allclose(ret_eps, eps)
_ = FM.compute_lambda(t)
Expand Down
8 changes: 6 additions & 2 deletions torchcfm/conditional_flow_matching.py
Original file line number Diff line number Diff line change
Expand Up @@ -309,7 +309,9 @@ def guided_sample_location_and_conditional_flow(
"""
x0, x1, y0, y1 = self.ot_sampler.sample_plan_with_labels(x0, x1, y0, y1)
if return_noise:
t, xt, ut, eps = super().sample_location_and_conditional_flow(x0, x1, t, return_noise, eps)
t, xt, ut, eps = super().sample_location_and_conditional_flow(
x0, x1, t, return_noise, eps
)
return t, xt, ut, y0, y1, eps
else:
t, xt, ut = super().sample_location_and_conditional_flow(x0, x1, t, return_noise, eps)
Expand Down Expand Up @@ -548,7 +550,9 @@ def guided_sample_location_and_conditional_flow(
"""
x0, x1, y0, y1 = self.ot_sampler.sample_plan_with_labels(x0, x1, y0, y1)
if return_noise:
t, xt, ut, eps = super().sample_location_and_conditional_flow(x0, x1, t, return_noise, eps)
t, xt, ut, eps = super().sample_location_and_conditional_flow(
x0, x1, t, return_noise, eps
)
return t, xt, ut, y0, y1, eps
else:
t, xt, ut = super().sample_location_and_conditional_flow(x0, x1, t, return_noise, eps)
Expand Down

0 comments on commit cdbbef6

Please sign in to comment.