Skip to content

Commit

Permalink
upgrade black, reformat
Browse files Browse the repository at this point in the history
AlexandraVolokhova committed Aug 9, 2024
1 parent 9b5f1ce commit 8651fc3
Showing 2 changed files with 15 additions and 15 deletions.
24 changes: 12 additions & 12 deletions gflownet/policy/multihead_tree.py
Original file line number Diff line number Diff line change
@@ -357,9 +357,9 @@ def forward(self, x):
logits[indices, self.leaf_index : self.feature_index] = y_leaf
logits[indices, self.eos_index] = y_eos
elif stage == Stage.LEAF:
logits[
indices, self.feature_index : self.threshold_index
] = self.feature_head(batch)
logits[indices, self.feature_index : self.threshold_index] = (
self.feature_head(batch)
)
else:
ks = [Tree.find_active(state) for state in states]
feature_index = torch.Tensor(
@@ -374,9 +374,9 @@ def forward(self, x):
if self.continuous:
logits[indices, (self.eos_index + 1) :] = head_output
else:
logits[
indices, self.threshold_index : self.operator_index
] = head_output
logits[indices, self.threshold_index : self.operator_index] = (
head_output
)
elif stage == Stage.THRESHOLD:
threshold = torch.Tensor(
[
@@ -464,14 +464,14 @@ def forward(self, x):
)

if stage == Stage.COMPLETE:
logits[
indices, self.operator_index : self.eos_index
] = self.complete_stage_head(batch)
logits[indices, self.operator_index : self.eos_index] = (
self.complete_stage_head(batch)
)
logits[indices, self.eos_index] = 1.0
elif stage == Stage.LEAF:
logits[
indices, self.leaf_index : self.feature_index
] = self.leaf_stage_head(batch)
logits[indices, self.leaf_index : self.feature_index] = (
self.leaf_stage_head(batch)
)
elif stage == Stage.FEATURE:
logits[indices, self.feature_index : self.threshold_index] = 1.0
elif stage == Stage.THRESHOLD:
6 changes: 3 additions & 3 deletions gflownet/utils/batch.py
Original file line number Diff line number Diff line change
@@ -873,9 +873,9 @@ def get_masks_forward(
masks_invalid_actions_forward_parents[parents_indices == -1] = self.source[
"mask_forward"
]
masks_invalid_actions_forward_parents[
parents_indices != -1
] = masks_invalid_actions_forward[parents_indices[parents_indices != -1]]
masks_invalid_actions_forward_parents[parents_indices != -1] = (
masks_invalid_actions_forward[parents_indices[parents_indices != -1]]
)
return masks_invalid_actions_forward_parents
return masks_invalid_actions_forward

0 comments on commit 8651fc3

Please sign in to comment.