Skip to content

Commit

Permalink
Merge pull request #271 from alexhernandezgarcia/fix-merge-losses
Browse files Browse the repository at this point in the history
Fix issues from merge of PR #254
  • Loading branch information
alexhernandezgarcia authored Dec 26, 2023
2 parents 72b4713 + bb09c7d commit 8c13f6d
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 8 deletions.
2 changes: 1 addition & 1 deletion gflownet/utils/batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -939,7 +939,7 @@ def _compute_rewards_source(self):
"""
# This will not work if source is randomised
if not self.conditional:
source_proxy = self.env.state2proxy(self.env.source)
source_proxy = torch.unsqueeze(self.env.state2proxy(self.env.source), dim=0)
reward_source = self.env.proxy2reward(self.env.proxy(source_proxy))
self.rewards_source = reward_source.expand(len(self))
else:
Expand Down
20 changes: 13 additions & 7 deletions tests/gflownet/utils/test_batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -1274,7 +1274,9 @@ def test__get_rewards_multiple_env_returns_expected_non_zero_non_terminating(
actions_iter.append(action)
valids_iter.append(valid)
rewards.append(env.reward(do_non_terminating=True))
proxy_values.append(env.proxy(env.state2proxy(env.state))[0])
proxy_values.append(
env.proxy(torch.unsqueeze(env.state2proxy(env.state), dim=0))[0]
)
# Add all envs, actions and valids to batch
batch.add_to_batch(envs, actions_iter, valids_iter)
# Remove done envs
Expand Down Expand Up @@ -1352,12 +1354,16 @@ def test__get_rewards_parents_multiple_env_returns_expected_non_terminating(
rewards_batch = batch.get_rewards(do_non_terminating=True)
rewards = torch.stack(rewards)

assert torch.equal(
rewards_parents_batch,
tfloat(rewards_parents, device=batch.device, float_type=batch.float),
assert torch.all(
torch.isclose(
rewards_parents_batch,
tfloat(rewards_parents, device=batch.device, float_type=batch.float),
)
), (rewards_parents, rewards_parents_batch)

assert torch.equal(
rewards_batch,
tfloat(rewards, device=batch.device, float_type=batch.float),
assert torch.all(
torch.isclose(
rewards_batch,
tfloat(rewards, device=batch.device, float_type=batch.float),
)
), (rewards, rewards_batch)

0 comments on commit 8c13f6d

Please sign in to comment.