Skip to content

Commit

Permalink
Fix: input to env.proxy must be a batch, even if single state
Browse files Browse the repository at this point in the history
  • Loading branch information
alexhernandezgarcia committed Dec 26, 2023
1 parent 72b4713 commit f3180e6
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 2 deletions.
2 changes: 1 addition & 1 deletion gflownet/utils/batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -939,7 +939,7 @@ def _compute_rewards_source(self):
"""
# This will not work if source is randomised
if not self.conditional:
source_proxy = self.env.state2proxy(self.env.source)
source_proxy = torch.unsqueeze(self.env.state2proxy(self.env.source), dim=0)
reward_source = self.env.proxy2reward(self.env.proxy(source_proxy))
self.rewards_source = reward_source.expand(len(self))
else:
Expand Down
4 changes: 3 additions & 1 deletion tests/gflownet/utils/test_batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -1274,7 +1274,9 @@ def test__get_rewards_multiple_env_returns_expected_non_zero_non_terminating(
actions_iter.append(action)
valids_iter.append(valid)
rewards.append(env.reward(do_non_terminating=True))
proxy_values.append(env.proxy(env.state2proxy(env.state))[0])
proxy_values.append(
env.proxy(torch.unsqueeze(env.state2proxy(env.state), dim=0))[0]
)
# Add all envs, actions and valids to batch
batch.add_to_batch(envs, actions_iter, valids_iter)
# Remove done envs
Expand Down

0 comments on commit f3180e6

Please sign in to comment.