From f3180e6e0c41c5d5f2c840befe6685618c0d250d Mon Sep 17 00:00:00 2001 From: alexhernandezgarcia Date: Tue, 26 Dec 2023 11:17:23 -0500 Subject: [PATCH] Fix: input to env.proxy must be a batch, even if single state --- gflownet/utils/batch.py | 2 +- tests/gflownet/utils/test_batch.py | 4 +++- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/gflownet/utils/batch.py b/gflownet/utils/batch.py index a58c1a7e7..59990cc5b 100644 --- a/gflownet/utils/batch.py +++ b/gflownet/utils/batch.py @@ -939,7 +939,7 @@ def _compute_rewards_source(self): """ # This will not work if source is randomised if not self.conditional: - source_proxy = self.env.state2proxy(self.env.source) + source_proxy = torch.unsqueeze(self.env.state2proxy(self.env.source), dim=0) reward_source = self.env.proxy2reward(self.env.proxy(source_proxy)) self.rewards_source = reward_source.expand(len(self)) else: diff --git a/tests/gflownet/utils/test_batch.py b/tests/gflownet/utils/test_batch.py index 2b3956038..617f5c82c 100644 --- a/tests/gflownet/utils/test_batch.py +++ b/tests/gflownet/utils/test_batch.py @@ -1274,7 +1274,9 @@ def test__get_rewards_multiple_env_returns_expected_non_zero_non_terminating( actions_iter.append(action) valids_iter.append(valid) rewards.append(env.reward(do_non_terminating=True)) - proxy_values.append(env.proxy(env.state2proxy(env.state))[0]) + proxy_values.append( + env.proxy(torch.unsqueeze(env.state2proxy(env.state), dim=0))[0] + ) # Add all envs, actions and valids to batch batch.add_to_batch(envs, actions_iter, valids_iter) # Remove done envs