[Feature] Deterministic sample for Masked one-hot
ghstack-source-id: 27787eab47324c5af152f706d81687e71b5b9803
Pull Request resolved: #2440
vmoens committed Sep 17, 2024
1 parent 0a410ff commit e294c68
Showing 1 changed file with 11 additions and 0 deletions.
11 changes: 11 additions & 0 deletions torchrl/modules/distributions/discrete.py
@@ -389,6 +389,17 @@ def sample(
    ) -> torch.Tensor:
        ...

    @property
    def deterministic_sample(self):
        return self.mode

    @property
    def mode(self) -> torch.Tensor:
        if hasattr(self, "logits"):
            return (self.logits == self.logits.max(-1, True)[0]).to(torch.long)
        else:
            return (self.probs == self.probs.max(-1, True)[0]).to(torch.long)

    def log_prob(self, value: torch.Tensor) -> torch.Tensor:
        return super().log_prob(value.argmax(dim=-1))
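
The `mode` property above compares each entry against the maximum along the last dimension. A minimal standalone sketch of that comparison follows, using plain `torch` only. The helper name `one_hot_mode` and the use of `-inf` logits to stand for masked-out actions are assumptions made for illustration, not part of the commit.

```python
import torch


def one_hot_mode(logits: torch.Tensor) -> torch.Tensor:
    """Return 1 wherever an entry equals the maximum of its row, 0 elsewhere.

    Hypothetical helper mirroring the comparison used in the diff above.
    """
    # max(..., keepdim=True) returns (values, indices); keep only the values
    max_values = logits.max(dim=-1, keepdim=True)[0]
    return (logits == max_values).to(torch.long)


# Assumption: masked-out actions are represented by -inf logits.
logits = torch.tensor(
    [
        [0.1, 2.0, float("-inf")],  # unique maximum at index 1
        [1.5, 1.5, 0.0],            # tie between indices 0 and 1
    ]
)
mode = one_hot_mode(logits)
print(mode)
```

In this sketch the first row yields `[0, 1, 0]`. The second row yields `[1, 1, 0]`, since an equality test against the maximum marks every tied entry, so the result is strictly one-hot only when the maximum is unique.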

1 comment on commit e294c68

@github-actions

⚠️ Performance Alert ⚠️

Possible performance regression was detected for benchmark 'CPU Benchmark Results'.
Benchmark result of this commit is worse than the previous benchmark result exceeding threshold 2.

| Benchmark suite | Current: e294c68 | Previous: 0a410ff | Ratio |
|-|-|-|-|
| benchmarks/test_replaybuffer_benchmark.py::test_rb_populate[TensorDictPrioritizedReplayBuffer-LazyMemmapStorage-None-400] | 26.423393389039916 iter/sec (stddev: 0.17216425061165694) | 76.28644808758818 iter/sec (stddev: 0.0012528477558623027) | 2.89 |
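
The reported ratio is consistent with dividing the previous throughput by the current one, which is an inference from the numbers rather than something the alert states. A short sketch of that arithmetic, with hypothetical variable names:

```python
# Throughput figures copied from the benchmark row above (iterations per second).
current_ips = 26.423393389039916
previous_ips = 76.28644808758818

# Alert threshold quoted in the comment.
alert_threshold = 2.0

# Assumption: for throughput benchmarks the ratio is previous / current,
# so a value above 1 means the current commit is slower.
ratio = previous_ips / current_ips
is_regression = ratio > alert_threshold

print(f"ratio = {ratio:.2f}, exceeds threshold: {is_regression}")
```

The current run also shows a much larger standard deviation than the previous one (about 0.17 versus 0.0013), so the measurement may be noisy.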

This comment was automatically generated by workflow using github-action-benchmark.

CC: @vmoens
