Skip to content

Commit

Permalink
[BugFix] Make probabilistic sequential modules compatible with compile
Browse files Browse the repository at this point in the history
ghstack-source-id: 9f1a3fc647a1976fc37c10bec7df1fdb8e5cbc08
Pull Request resolved: #1030
  • Loading branch information
vmoens committed Oct 4, 2024
1 parent 70f5888 commit 1f770b9
Showing 1 changed file with 2 additions and 4 deletions.
6 changes: 2 additions & 4 deletions tensordict/nn/probabilistic.py
Original file line number Diff line number Diff line change
Expand Up @@ -600,14 +600,12 @@ def __init__(
# distribution using `get_dist` or to sample log_probabilities
_, out_keys = self._compute_in_and_out_keys(modules[:-1])
self._requires_sample = modules[-1].out_keys[0] not in set(out_keys)
self.__dict__["_det_part"] = TensorDictSequential(*modules[:-1])
super().__init__(*modules, partial_tolerant=partial_tolerant)

@property
def det_part(self):
if not hasattr(self, "_det_part"):
# we use a list to avoid having the submodules listed in module.modules()
self._det_part = [TensorDictSequential(*self.module[:-1])]
return self._det_part[0]
return self._det_part

def get_dist_params(
self,
Expand Down

0 comments on commit 1f770b9

Please sign in to comment.