Commit: Update

[ghstack-poisoned]
vmoens committed Nov 23, 2024
1 parent 0202562 commit f223205
Showing 3 changed files with 62 additions and 11 deletions.
16 changes: 6 additions & 10 deletions tensordict/nn/probabilistic.py
@@ -283,7 +283,6 @@ def __init__(
         in_keys: NestedKey | List[NestedKey] | Dict[str, NestedKey],
         out_keys: NestedKey | List[NestedKey] | None = None,
         *,
-        default_interaction_mode: str | None = None,
         default_interaction_type: InteractionType = InteractionType.DETERMINISTIC,
         distribution_class: type = Delta,
         distribution_kwargs: dict | None = None,
@@ -332,10 +331,6 @@ def __init__(
             log_prob_key = "sample_log_prob"
         self.log_prob_key = log_prob_key
 
-        if default_interaction_mode is not None:
-            raise ValueError(
-                "default_interaction_mode is deprecated, use default_interaction_type instead."
-            )
         self.default_interaction_type = InteractionType(default_interaction_type)
 
         if isinstance(distribution_class, str):
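Note on the two hunks above (not part of the diff): the deprecated default_interaction_mode argument is now removed outright rather than raising a ValueError, so callers pass default_interaction_type instead, either an InteractionType member or its string name (the constructor converts it with InteractionType(...)). A minimal usage sketch of the surviving API, with illustrative key names:

    import torch
    from torch.distributions import Normal
    from tensordict import TensorDict
    from tensordict.nn import InteractionType, ProbabilisticTensorDictModule

    prob_mod = ProbabilisticTensorDictModule(
        in_keys={"loc": "loc", "scale": "scale"},  # map distribution kwargs to tensordict keys
        out_keys=["action"],
        distribution_class=Normal,
        default_interaction_type=InteractionType.RANDOM,  # or the string "random"
    )
    td = TensorDict({"loc": torch.zeros(3), "scale": torch.ones(3)}, batch_size=[3])
    td = prob_mod(td)  # writes a sampled "action" entry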
@@ -356,7 +351,7 @@ def get_dist(self, tensordict: TensorDictBase) -> D.Distribution:
         for dist_key, td_key in _zip_strict(self.dist_keys, self.in_keys):
             if isinstance(dist_key, tuple):
                 dist_key = dist_key[-1]
-            dist_kwargs[dist_key] = tensordict.get(td_key)
+            dist_kwargs[dist_key] = tensordict.get(td_key, None)
         dist = self.distribution_class(
             **dist_kwargs, **_dynamo_friendly_to_dict(self.distribution_kwargs)
         )
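Note on the get(td_key, None) change above (my reading; the commit message does not say): with an explicit None default, an in_key that is absent from the input tensordict simply contributes None to the distribution kwargs, rather than relying on get's no-default behavior, which depending on the tensordict version may warn or raise. A small sketch of that difference:

    import torch
    from tensordict import TensorDict

    td = TensorDict({"loc": torch.zeros(3)}, batch_size=[3])  # no "scale" entry
    value = td.get("scale", None)  # always None when the key is absent
    # td.get("scale") with no default is version-dependent (warning or KeyError),
    # which is presumably why the explicit None is passed here.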
Expand Down Expand Up @@ -630,12 +625,13 @@ def forward(
tensordict_out: TensorDictBase | None = None,
**kwargs,
) -> TensorDictBase:
if (tensordict_out is None and self._select_before_return) or (tensordict_out is not None):
if (tensordict_out is None and self._select_before_return) or (
tensordict_out is not None
):
tensordict_exec = tensordict.copy()
else:
tensordict_exec = tensordict
tensordict_exec = self.get_dist_params(tensordict_exec, tensordict_out, **kwargs
)
tensordict_exec = self.get_dist_params(tensordict_exec, **kwargs)
tensordict_exec = self.module[-1](
tensordict_exec, _requires_sample=self._requires_sample
)
@@ -645,7 +641,7 @@ def forward(
         else:
             result = tensordict_exec
         if self._select_before_return:
-            return tensordict.update(result.select(*self.out_keys))
+            return tensordict.update(result, keys_to_update=self.out_keys)
         return result
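Note on the update(result, keys_to_update=self.out_keys) line above: TensorDictBase.update accepts a keys_to_update argument that restricts which keys get copied, so the temporary tensordict that result.select(*self.out_keys) used to build is no longer materialized. A rough equivalence sketch with illustrative keys:

    import torch
    from tensordict import TensorDict

    dest = TensorDict({"a": torch.zeros(3)}, batch_size=[3])
    src = TensorDict({"a": torch.ones(3), "f": torch.ones(3)}, batch_size=[3])

    dest.update(src.select("f"))            # old pattern: build a selected copy first
    dest.update(src, keys_to_update=["f"])  # new pattern: filter inside update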


4 changes: 3 additions & 1 deletion tensordict/nn/sequence.py
@@ -470,7 +470,9 @@ def forward(
         tensordict_out: TensorDictBase | None = None,
         **kwargs: Any,
     ) -> TensorDictBase:
-        if (tensordict_out is None and self._select_before_return) or (tensordict_out is not None):
+        if (tensordict_out is None and self._select_before_return) or (
+            tensordict_out is not None
+        ):
             tensordict_exec = tensordict.copy()
         else:
             tensordict_exec = tensordict
53 changes: 53 additions & 0 deletions test/test_nn.py
@@ -890,6 +890,59 @@ def test_stateful_probabilistic_deprec(self, lazy):
         dist = tdmodule.get_dist(td)
         assert dist.rsample().shape[: td.ndimension()] == td.shape
 
+    @pytest.mark.parametrize("return_log_prob", [True, False])
+    @pytest.mark.parametrize("td_out", [True, False])
+    def test_probtdseq(self, return_log_prob, td_out):
+        mod = ProbabilisticTensorDictSequential(
+            TensorDictModule(lambda x: x + 2, in_keys=["a"], out_keys=["c"]),
+            TensorDictModule(lambda x: (x + 2, x), in_keys=["b"], out_keys=["d", "e"]),
+            ProbabilisticTensorDictModule(
+                in_keys={"loc": "d", "scale": "e"},
+                out_keys=["f"],
+                distribution_class=Normal,
+                return_log_prob=return_log_prob,
+                default_interaction_type="random",
+            ),
+        )
+        inp = TensorDict({"a": 0.0, "b": 1.0})
+        inp_clone = inp.clone()
+        if td_out:
+            out = TensorDict()
+        else:
+            out = None
+        out2 = mod(inp, tensordict_out=out)
+        assert not mod._select_before_return
+        if td_out:
+            assert out is out2
+        else:
+            assert out2 is inp
+        assert set(out2.keys()) - {"a", "b"} == set(mod.out_keys), (
+            td_out,
+            return_log_prob,
+        )
+
+        inp = inp_clone.clone()
+        mod.select_out_keys("f")
+        if td_out:
+            out = TensorDict()
+        else:
+            out = None
+        out2 = mod(inp, tensordict_out=out)
+        assert mod._select_before_return
+        if td_out:
+            assert out is out2
+        else:
+            assert out2 is inp
+        expected = {"f"}
+        if td_out:
+            assert set(out2.keys()) == set(mod.out_keys) == expected
+        else:
+            assert (
+                set(out2.keys()) - set(inp_clone.keys())
+                == set(mod.out_keys)
+                == expected
+            )
+
     @pytest.mark.parametrize("lazy", [True, False])
     def test_stateful_probabilistic(self, lazy):
         torch.manual_seed(0)
