Skip to content

Commit

Permalink
Update
Browse files Browse the repository at this point in the history
[ghstack-poisoned]
  • Loading branch information
vmoens committed Sep 17, 2024
1 parent 200e5b6 commit 3d7a5c4
Showing 1 changed file with 11 additions and 0 deletions.
11 changes: 11 additions & 0 deletions torchrl/modules/distributions/discrete.py
Original file line number Diff line number Diff line change
Expand Up @@ -389,6 +389,17 @@ def sample(
) -> torch.Tensor:
...

@property
def deterministic_sample(self):
return self.mode

@property
def mode(self) -> torch.Tensor:
if hasattr(self, "logits"):
return (self.logits == self.logits.max(-1, True)[0]).to(torch.long)
else:
return (self.probs == self.probs.max(-1, True)[0]).to(torch.long)

def log_prob(self, value: torch.Tensor) -> torch.Tensor:
return super().log_prob(value.argmax(dim=-1))

Expand Down

0 comments on commit 3d7a5c4

Please sign in to comment.