From bb09c7d35da8a369626599eaa9797f606f5a917e Mon Sep 17 00:00:00 2001 From: alexhernandezgarcia Date: Tue, 26 Dec 2023 12:04:44 -0500 Subject: [PATCH] Fix: make torch.equal torch.all(torch.isclose(... --- tests/gflownet/utils/test_batch.py | 16 ++++++++++------ 1 file changed, 10 insertions(+), 6 deletions(-) diff --git a/tests/gflownet/utils/test_batch.py b/tests/gflownet/utils/test_batch.py index 617f5c82c..9c25e0d0f 100644 --- a/tests/gflownet/utils/test_batch.py +++ b/tests/gflownet/utils/test_batch.py @@ -1354,12 +1354,16 @@ def test__get_rewards_parents_multiple_env_returns_expected_non_terminating( rewards_batch = batch.get_rewards(do_non_terminating=True) rewards = torch.stack(rewards) - assert torch.equal( - rewards_parents_batch, - tfloat(rewards_parents, device=batch.device, float_type=batch.float), + assert torch.all( + torch.isclose( + rewards_parents_batch, + tfloat(rewards_parents, device=batch.device, float_type=batch.float), + ) ), (rewards_parents, rewards_parents_batch) - assert torch.equal( - rewards_batch, - tfloat(rewards, device=batch.device, float_type=batch.float), + assert torch.all( + torch.isclose( + rewards_batch, + tfloat(rewards, device=batch.device, float_type=batch.float), + ) ), (rewards, rewards_batch)