From f3180e6e0c41c5d5f2c840befe6685618c0d250d Mon Sep 17 00:00:00 2001 From: alexhernandezgarcia Date: Tue, 26 Dec 2023 11:17:23 -0500 Subject: [PATCH 1/2] Fix: input to env.proxy must be a batch, even if single state --- gflownet/utils/batch.py | 2 +- tests/gflownet/utils/test_batch.py | 4 +++- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/gflownet/utils/batch.py b/gflownet/utils/batch.py index a58c1a7e7..59990cc5b 100644 --- a/gflownet/utils/batch.py +++ b/gflownet/utils/batch.py @@ -939,7 +939,7 @@ def _compute_rewards_source(self): """ # This will not work if source is randomised if not self.conditional: - source_proxy = self.env.state2proxy(self.env.source) + source_proxy = torch.unsqueeze(self.env.state2proxy(self.env.source), dim=0) reward_source = self.env.proxy2reward(self.env.proxy(source_proxy)) self.rewards_source = reward_source.expand(len(self)) else: diff --git a/tests/gflownet/utils/test_batch.py b/tests/gflownet/utils/test_batch.py index 2b3956038..617f5c82c 100644 --- a/tests/gflownet/utils/test_batch.py +++ b/tests/gflownet/utils/test_batch.py @@ -1274,7 +1274,9 @@ def test__get_rewards_multiple_env_returns_expected_non_zero_non_terminating( actions_iter.append(action) valids_iter.append(valid) rewards.append(env.reward(do_non_terminating=True)) - proxy_values.append(env.proxy(env.state2proxy(env.state))[0]) + proxy_values.append( + env.proxy(torch.unsqueeze(env.state2proxy(env.state), dim=0))[0] + ) # Add all envs, actions and valids to batch batch.add_to_batch(envs, actions_iter, valids_iter) # Remove done envs From bb09c7d35da8a369626599eaa9797f606f5a917e Mon Sep 17 00:00:00 2001 From: alexhernandezgarcia Date: Tue, 26 Dec 2023 12:04:44 -0500 Subject: [PATCH 2/2] Fix: make torch.equal torch.all(torch.isclose(... --- tests/gflownet/utils/test_batch.py | 16 ++++++++++------ 1 file changed, 10 insertions(+), 6 deletions(-) diff --git a/tests/gflownet/utils/test_batch.py b/tests/gflownet/utils/test_batch.py index 617f5c82c..9c25e0d0f 100644 --- a/tests/gflownet/utils/test_batch.py +++ b/tests/gflownet/utils/test_batch.py @@ -1354,12 +1354,16 @@ def test__get_rewards_parents_multiple_env_returns_expected_non_terminating( rewards_batch = batch.get_rewards(do_non_terminating=True) rewards = torch.stack(rewards) - assert torch.equal( - rewards_parents_batch, - tfloat(rewards_parents, device=batch.device, float_type=batch.float), + assert torch.all( + torch.isclose( + rewards_parents_batch, + tfloat(rewards_parents, device=batch.device, float_type=batch.float), + ) ), (rewards_parents, rewards_parents_batch) - assert torch.equal( - rewards_batch, - tfloat(rewards, device=batch.device, float_type=batch.float), + assert torch.all( + torch.isclose( + rewards_batch, + tfloat(rewards, device=batch.device, float_type=batch.float), + ) ), (rewards, rewards_batch)