Skip to content

Commit

Permalink
Cleanup test
Browse files Browse the repository at this point in the history
  • Loading branch information
atong01 committed Nov 30, 2023
1 parent ea6e5d0 commit d0202a8
Showing 1 changed file with 9 additions and 21 deletions.
30 changes: 9 additions & 21 deletions tests/test_time_t.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,27 +57,15 @@ def test_guided_random_Tensor_t(FM, return_noise):
x1 = torch.randn(batch_size, 2)
y1 = torch.randint(high=10, size=(batch_size, 1))

if return_noise:
torch.manual_seed(seed)
t_given = torch.rand(batch_size)
t_given, xt, ut, y0, y1, eps = FM.guided_sample_location_and_conditional_flow(
x0, x1, y0=y0, y1=y1, t=t_given, return_noise=return_noise
)

torch.manual_seed(seed)
t_random, xt, ut, y0, y1, eps = FM.guided_sample_location_and_conditional_flow(
x0, x1, y0=y0, y1=y1, t=None, return_noise=return_noise
)
else:
torch.manual_seed(seed)
t_given = torch.rand(batch_size)
t_given, xt, ut, y0, y1 = FM.guided_sample_location_and_conditional_flow(
x0, x1, y0=y0, y1=y1, t=t_given
)
torch.manual_seed(seed)
t_given = torch.rand(batch_size)
t_given = FM.guided_sample_location_and_conditional_flow(
x0, x1, y0=y0, y1=y1, t=t_given, return_noise=return_noise
)[0]

torch.manual_seed(seed)
t_random, xt, ut, y0, y1 = FM.guided_sample_location_and_conditional_flow(
x0, x1, y0=y0, y1=y1, t=None
)
torch.manual_seed(seed)
t_random = FM.guided_sample_location_and_conditional_flow(
x0, x1, y0=y0, y1=y1, t=None, return_noise=return_noise
)[0]

assert any(t_given == t_random)

0 comments on commit d0202a8

Please sign in to comment.