Skip to content

Commit

Permalink
Fix: make torch.equal torch.all(torch.isclose(...
Browse files Browse the repository at this point in the history
  • Loading branch information
alexhernandezgarcia committed Dec 26, 2023
1 parent f3180e6 commit bb09c7d
Showing 1 changed file with 10 additions and 6 deletions.
16 changes: 10 additions & 6 deletions tests/gflownet/utils/test_batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -1354,12 +1354,16 @@ def test__get_rewards_parents_multiple_env_returns_expected_non_terminating(
rewards_batch = batch.get_rewards(do_non_terminating=True)
rewards = torch.stack(rewards)

assert torch.equal(
rewards_parents_batch,
tfloat(rewards_parents, device=batch.device, float_type=batch.float),
assert torch.all(
torch.isclose(
rewards_parents_batch,
tfloat(rewards_parents, device=batch.device, float_type=batch.float),
)
), (rewards_parents, rewards_parents_batch)

assert torch.equal(
rewards_batch,
tfloat(rewards, device=batch.device, float_type=batch.float),
assert torch.all(
torch.isclose(
rewards_batch,
tfloat(rewards, device=batch.device, float_type=batch.float),
)
), (rewards, rewards_batch)

0 comments on commit bb09c7d

Please sign in to comment.