diff --git a/gflownet/utils/batch.py b/gflownet/utils/batch.py index a58c1a7e7..59990cc5b 100644 --- a/gflownet/utils/batch.py +++ b/gflownet/utils/batch.py @@ -939,7 +939,7 @@ def _compute_rewards_source(self): """ # This will not work if source is randomised if not self.conditional: - source_proxy = self.env.state2proxy(self.env.source) + source_proxy = torch.unsqueeze(self.env.state2proxy(self.env.source), dim=0) reward_source = self.env.proxy2reward(self.env.proxy(source_proxy)) self.rewards_source = reward_source.expand(len(self)) else: diff --git a/tests/gflownet/utils/test_batch.py b/tests/gflownet/utils/test_batch.py index 2b3956038..9c25e0d0f 100644 --- a/tests/gflownet/utils/test_batch.py +++ b/tests/gflownet/utils/test_batch.py @@ -1274,7 +1274,9 @@ def test__get_rewards_multiple_env_returns_expected_non_zero_non_terminating( actions_iter.append(action) valids_iter.append(valid) rewards.append(env.reward(do_non_terminating=True)) - proxy_values.append(env.proxy(env.state2proxy(env.state))[0]) + proxy_values.append( + env.proxy(torch.unsqueeze(env.state2proxy(env.state), dim=0))[0] + ) # Add all envs, actions and valids to batch batch.add_to_batch(envs, actions_iter, valids_iter) # Remove done envs @@ -1352,12 +1354,16 @@ def test__get_rewards_parents_multiple_env_returns_expected_non_terminating( rewards_batch = batch.get_rewards(do_non_terminating=True) rewards = torch.stack(rewards) - assert torch.equal( - rewards_parents_batch, - tfloat(rewards_parents, device=batch.device, float_type=batch.float), + assert torch.all( + torch.isclose( + rewards_parents_batch, + tfloat(rewards_parents, device=batch.device, float_type=batch.float), + ) ), (rewards_parents, rewards_parents_batch) - assert torch.equal( - rewards_batch, - tfloat(rewards, device=batch.device, float_type=batch.float), + assert torch.all( + torch.isclose( + rewards_batch, + tfloat(rewards, device=batch.device, float_type=batch.float), + ) ), (rewards, rewards_batch)