diff --git a/tests/test_time_t.py b/tests/test_time_t.py index e4c5e46..d79ef7e 100644 --- a/tests/test_time_t.py +++ b/tests/test_time_t.py @@ -57,27 +57,15 @@ def test_guided_random_Tensor_t(FM, return_noise): x1 = torch.randn(batch_size, 2) y1 = torch.randint(high=10, size=(batch_size, 1)) - if return_noise: - torch.manual_seed(seed) - t_given = torch.rand(batch_size) - t_given, xt, ut, y0, y1, eps = FM.guided_sample_location_and_conditional_flow( - x0, x1, y0=y0, y1=y1, t=t_given, return_noise=return_noise - ) - - torch.manual_seed(seed) - t_random, xt, ut, y0, y1, eps = FM.guided_sample_location_and_conditional_flow( - x0, x1, y0=y0, y1=y1, t=None, return_noise=return_noise - ) - else: - torch.manual_seed(seed) - t_given = torch.rand(batch_size) - t_given, xt, ut, y0, y1 = FM.guided_sample_location_and_conditional_flow( - x0, x1, y0=y0, y1=y1, t=t_given - ) + torch.manual_seed(seed) + t_given = torch.rand(batch_size) + t_given = FM.guided_sample_location_and_conditional_flow( + x0, x1, y0=y0, y1=y1, t=t_given, return_noise=return_noise + )[0] - torch.manual_seed(seed) - t_random, xt, ut, y0, y1 = FM.guided_sample_location_and_conditional_flow( - x0, x1, y0=y0, y1=y1, t=None - ) + torch.manual_seed(seed) + t_random = FM.guided_sample_location_and_conditional_flow( + x0, x1, y0=y0, y1=y1, t=None, return_noise=return_noise + )[0] assert any(t_given == t_random)