diff --git a/tensordict/nn/probabilistic.py b/tensordict/nn/probabilistic.py
index ecb90a73d..d28443831 100644
--- a/tensordict/nn/probabilistic.py
+++ b/tensordict/nn/probabilistic.py
@@ -295,6 +295,7 @@ def __init__(
         log_prob_key: Optional[NestedKey] = "sample_log_prob",
         cache_dist: bool = False,
         n_empirical_estimate: int = 1000,
+        num_samples: int | torch.Size | None = None,
     ) -> None:
         super().__init__()
         distribution_kwargs = (
@@ -346,6 +347,9 @@ def __init__(
         self._dist = None
         self.cache_dist = cache_dist if hasattr(distribution_class, "update") else False
         self.return_log_prob = return_log_prob
+        if isinstance(num_samples, (int, torch.SymInt)):
+            num_samples = torch.Size((num_samples,))
+        self.num_samples = num_samples
         if self.return_log_prob and self.log_prob_key not in self.out_keys:
             self.out_keys.append(self.log_prob_key)
 
@@ -464,6 +468,11 @@ def forward(
         dist = self.get_dist(tensordict)
         if _requires_sample:
             out_tensors = self._dist_sample(dist, interaction_type=interaction_type())
+            if self.num_samples is not None:
+                # TODO: capture contiguous error here
+                tensordict_out = tensordict_out.expand(
+                    self.num_samples + tensordict_out.shape
+                )
             if isinstance(out_tensors, TensorDictBase):
                 if self.return_log_prob:
                     kwargs = {}
@@ -571,10 +580,13 @@ def _dist_sample(
                 return dist.sample((self.n_empirical_estimate,)).mean(0)
 
         elif interaction_type is InteractionType.RANDOM:
+            num_samples = self.num_samples
+            if num_samples is None:
+                num_samples = torch.Size(())
             if dist.has_rsample:
-                return dist.rsample()
+                return dist.rsample(num_samples)
             else:
-                return dist.sample()
+                return dist.sample(num_samples)
         else:
             raise NotImplementedError(f"unknown interaction_type {interaction_type}")
 
@@ -858,6 +870,16 @@ def get_dist_params(
         with set_interaction_type(type):
             return tds(tensordict, tensordict_out, **kwargs)
 
+    @property
+    def num_samples(self):
+        num_samples = ()
+        for tdm in self.module:
+            if isinstance(
+                tdm, (ProbabilisticTensorDictModule, ProbabilisticTensorDictSequential)
+            ):
+                num_samples = tdm.num_samples + num_samples
+        return num_samples
+
     def get_dist(
         self,
         tensordict: TensorDictBase,
@@ -902,6 +924,8 @@ def get_dist(
             dist = tdm.get_dist(td_copy)
             if i < len(self.module) - 1:
                 sample = tdm._dist_sample(dist, interaction_type=interaction_type())
+                if tdm.num_samples not in ((), None):
+                    td_copy = td_copy.expand(tdm.num_samples + td_copy.shape)
                 if isinstance(tdm, ProbabilisticTensorDictModule):
                     if isinstance(sample, torch.Tensor):
                         sample = [sample]
@@ -1119,7 +1143,10 @@ def forward(
         else:
             result = tensordict_exec
         if self._select_before_return:
-            return tensordict.update(result, keys_to_update=self.out_keys)
+            # We must also update any value that has been updated during the course of execution
+            # from the input data.
+            keys = list(set(self.out_keys + list(tensordict.keys(True, True))))
+            return tensordict.update(result, keys_to_update=keys)
         return result
 
 
diff --git a/test/test_nn.py b/test/test_nn.py
index ef82baba4..0d6085ef9 100644
--- a/test/test_nn.py
+++ b/test/test_nn.py
@@ -1866,100 +1866,225 @@ def test_module_buffer():
     assert module.td.device.type == "cuda"
 
 
-@pytest.mark.parametrize(
-    "log_prob_key",
-    [
-        None,
-        "sample_log_prob",
-        ("nested", "sample_log_prob"),
-        ("data", "sample_log_prob"),
-    ],
-)
-def test_nested_keys_probabilistic_delta(log_prob_key):
-    policy_module = TensorDictModule(
-        nn.Linear(1, 1), in_keys=[("data", "states")], out_keys=[("data", "param")]
-    )
-    td = TensorDict({"data": TensorDict({"states": torch.zeros(3, 4, 1)}, [3, 4])}, [3])
-
-    module = ProbabilisticTensorDictModule(
-        in_keys=[("data", "param")],
-        out_keys=[("data", "action")],
-        distribution_class=Delta,
-        return_log_prob=True,
-        log_prob_key=log_prob_key,
-    )
-    td_out = module(policy_module(td))
-    assert td_out["data", "action"].shape == (3, 4, 1)
-    if log_prob_key:
-        assert td_out[log_prob_key].shape == (3, 4)
-    else:
-        assert td_out["sample_log_prob"].shape == (3, 4)
+class TestProbabilisticTensorDictModule:
+    @pytest.mark.parametrize("return_log_prob", [True, False])
+    def test_probabilistic_n_samples(self, return_log_prob):
+        prob = ProbabilisticTensorDictModule(
+            in_keys=["loc"],
+            out_keys=["sample"],
+            distribution_class=Normal,
+            distribution_kwargs={"scale": 1},
+            return_log_prob=return_log_prob,
+            num_samples=2,
+            default_interaction_type="random",
+        )
+        # alone
+        td = TensorDict(loc=torch.randn(3, 4), batch_size=[3])
+        td = prob(td)
+        assert "sample" in td
+        assert td.shape == (2, 3)
+        assert td["sample"].shape == (2, 3, 4)
+        if return_log_prob:
+            assert "sample_log_prob" in td
+
+    @pytest.mark.parametrize("return_log_prob", [True, False])
+    @pytest.mark.parametrize("inplace", [True, False])
+    @pytest.mark.parametrize("include_sum", [True, False])
+    @pytest.mark.parametrize("aggregate_probabilities", [True, False])
+    @pytest.mark.parametrize("return_composite", [True, False])
+    def test_probabilistic_seq_n_samples(
+        self,
+        return_log_prob,
+        return_composite,
+        include_sum,
+        aggregate_probabilities,
+        inplace,
+    ):
+        prob = ProbabilisticTensorDictModule(
+            in_keys=["loc"],
+            out_keys=["sample"],
+            distribution_class=Normal,
+            distribution_kwargs={"scale": 1},
+            return_log_prob=return_log_prob,
+            num_samples=2,
+            default_interaction_type="random",
+        )
+        # in a sequence
+        seq = ProbabilisticTensorDictSequential(
+            TensorDictModule(lambda x: x + 1, in_keys=["x"], out_keys=["loc"]),
+            prob,
+            inplace=inplace,
+            aggregate_probabilities=aggregate_probabilities,
+            include_sum=include_sum,
+            return_composite=return_composite,
+        )
+        td = TensorDict(x=torch.randn(3, 4), batch_size=[3])
+        if return_composite:
+            assert isinstance(seq.get_dist(td), CompositeDistribution)
+        else:
+            assert isinstance(seq.get_dist(td), Normal)
+        td = seq(td)
+        assert "sample" in td
+        assert td.shape == (2, 3)
+        assert td["sample"].shape == (2, 3, 4)
+        if return_log_prob:
+            assert "sample_log_prob" in td
+        # log-prob from the sequence
+        log_prob = seq.log_prob(td)
+        if aggregate_probabilities or not return_composite:
+            assert isinstance(log_prob, torch.Tensor)
+        else:
+            assert isinstance(log_prob, TensorDict)
 
-@pytest.mark.parametrize(
-    "log_prob_key",
-    [
-        None,
-        "sample_log_prob",
-        ("nested", "sample_log_prob"),
-        ("data", "sample_log_prob"),
-    ],
-)
-def test_nested_keys_probabilistic_normal(log_prob_key):
-    loc_module = TensorDictModule(
-        nn.Linear(1, 1),
-        in_keys=[("data", "states")],
-        out_keys=[("data", "loc")],
-    )
-    scale_module = TensorDictModule(
-        nn.Linear(1, 1),
-        in_keys=[("data", "states")],
-        out_keys=[("data", "scale")],
-    )
-    td = TensorDict({"data": TensorDict({"states": torch.zeros(3, 4, 1)}, [3, 4])}, [3])
-
-    module = ProbabilisticTensorDictModule(
-        in_keys=[("data", "loc"), ("data", "scale")],
-        out_keys=[("data", "action")],
-        distribution_class=Normal,
-        return_log_prob=True,
-        log_prob_key=log_prob_key,
+    @pytest.mark.parametrize("return_log_prob", [True, False])
+    @pytest.mark.parametrize("inplace", [True, False])
+    @pytest.mark.parametrize("include_sum", [True, False])
+    @pytest.mark.parametrize("aggregate_probabilities", [True, False])
+    @pytest.mark.parametrize("return_composite", [True])
+    def test_intermediate_probabilistic_seq_n_samples(
+        self,
+        return_log_prob,
+        return_composite,
+        include_sum,
+        aggregate_probabilities,
+        inplace,
+    ):
+        prob = ProbabilisticTensorDictModule(
+            in_keys=["loc"],
+            out_keys=["sample"],
+            distribution_class=Normal,
+            distribution_kwargs={"scale": 1},
+            return_log_prob=return_log_prob,
+            num_samples=2,
+            default_interaction_type="random",
+        )
+
+        # intermediate in a sequence
+        seq = ProbabilisticTensorDictSequential(
+            TensorDictModule(lambda x: x + 1, in_keys=["x"], out_keys=["loc"]),
+            prob,
+            TensorDictModule(
+                lambda x: x + 1, in_keys=["sample"], out_keys=["new_sample"]
+            ),
+            inplace=inplace,
+            aggregate_probabilities=aggregate_probabilities,
+            include_sum=include_sum,
+            return_composite=return_composite,
+        )
+        td = TensorDict(x=torch.randn(3, 4), batch_size=[3])
+        assert isinstance(seq.get_dist(td), CompositeDistribution)
+        td = seq(td)
+        assert "sample" in td
+        assert td.shape == (2, 3)
+        assert td["sample"].shape == (2, 3, 4)
+        if return_log_prob:
+            assert "sample_log_prob" in td
+
+        # log-prob from the sequence
+        log_prob = seq.log_prob(td)
+        if aggregate_probabilities or not return_composite:
+            assert isinstance(log_prob, torch.Tensor)
+        else:
+            assert isinstance(log_prob, TensorDict)
+
+    @pytest.mark.parametrize(
+        "log_prob_key",
+        [
+            None,
+            "sample_log_prob",
+            ("nested", "sample_log_prob"),
+            ("data", "sample_log_prob"),
+        ],
     )
-    with pytest.warns(UserWarning, match="deterministic_sample"):
-        td_out = module(loc_module(scale_module(td)))
+    def test_nested_keys_probabilistic_delta(self, log_prob_key):
+        policy_module = TensorDictModule(
+            nn.Linear(1, 1), in_keys=[("data", "states")], out_keys=[("data", "param")]
+        )
+        td = TensorDict(
+            {"data": TensorDict({"states": torch.zeros(3, 4, 1)}, [3, 4])}, [3]
+        )
+
+        module = ProbabilisticTensorDictModule(
+            in_keys=[("data", "param")],
+            out_keys=[("data", "action")],
+            distribution_class=Delta,
+            return_log_prob=True,
+            log_prob_key=log_prob_key,
+        )
+        td_out = module(policy_module(td))
         assert td_out["data", "action"].shape == (3, 4, 1)
         if log_prob_key:
-        assert td_out[log_prob_key].shape == (3, 4, 1)
+            assert td_out[log_prob_key].shape == (3, 4)
         else:
-        assert td_out["sample_log_prob"].shape == (3, 4, 1)
+            assert td_out["sample_log_prob"].shape == (3, 4)
 
         module = ProbabilisticTensorDictModule(
-        in_keys={"loc": ("data", "loc"), "scale": ("data", "scale")},
+            in_keys={"param": ("data", "param")},
             out_keys=[("data", "action")],
-        distribution_class=Normal,
+            distribution_class=Delta,
             return_log_prob=True,
             log_prob_key=log_prob_key,
         )
-    td_out = module(loc_module(scale_module(td)))
+        td_out = module(policy_module(td))
         assert td_out["data", "action"].shape == (3, 4, 1)
         if log_prob_key:
-        assert td_out[log_prob_key].shape == (3, 4, 1)
+            assert td_out[log_prob_key].shape == (3, 4)
         else:
-        assert td_out["sample_log_prob"].shape == (3, 4, 1)
+            assert td_out["sample_log_prob"].shape == (3, 4)
+
+    @pytest.mark.parametrize(
+        "log_prob_key",
+        [
+            None,
+            "sample_log_prob",
+            ("nested", "sample_log_prob"),
+            ("data", "sample_log_prob"),
+        ],
+    )
+    def test_nested_keys_probabilistic_normal(self, log_prob_key):
+        loc_module = TensorDictModule(
+            nn.Linear(1, 1),
+            in_keys=[("data", "states")],
+            out_keys=[("data", "loc")],
+        )
+        scale_module = TensorDictModule(
+            nn.Linear(1, 1),
+            in_keys=[("data", "states")],
+            out_keys=[("data", "scale")],
+        )
+        td = TensorDict(
+            {"data": TensorDict({"states": torch.zeros(3, 4, 1)}, [3, 4])}, [3]
+        )
+
+        module = ProbabilisticTensorDictModule(
+            in_keys=[("data", "loc"), ("data", "scale")],
+            out_keys=[("data", "action")],
+            distribution_class=Normal,
+            return_log_prob=True,
+            log_prob_key=log_prob_key,
+        )
+        with pytest.warns(UserWarning, match="deterministic_sample"):
+            td_out = module(loc_module(scale_module(td)))
         assert td_out["data", "action"].shape == (3, 4, 1)
         if log_prob_key:
             assert td_out[log_prob_key].shape == (3, 4, 1)
         else:
             assert td_out["sample_log_prob"].shape == (3, 4, 1)
+
+        module = ProbabilisticTensorDictModule(
+            in_keys={"loc": ("data", "loc"), "scale": ("data", "scale")},
+            out_keys=[("data", "action")],
+            distribution_class=Normal,
+            return_log_prob=True,
+            log_prob_key=log_prob_key,
+        )
+        td_out = module(loc_module(scale_module(td)))
+        assert td_out["data", "action"].shape == (3, 4, 1)
+        if log_prob_key:
+            assert td_out[log_prob_key].shape == (3, 4, 1)
+        else:
+            assert td_out["sample_log_prob"].shape == (3, 4, 1)
 
 
 class TestEnsembleModule:
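A minimal usage sketch of the new `num_samples` argument, mirroring the values and shapes exercised in the tests above (`Normal` with `scale=1`, `num_samples=2`, expanded batch size `(2, 3)`). It assumes this patch is applied and is meant as an illustration of the intended behaviour, not additional API surface:

```python
# Sketch only: assumes the patch above is applied; names and shapes mirror the new tests.
import torch
from tensordict import TensorDict
from tensordict.nn import (
    ProbabilisticTensorDictModule,
    ProbabilisticTensorDictSequential,
    TensorDictModule,
)
from torch.distributions import Normal

prob = ProbabilisticTensorDictModule(
    in_keys=["loc"],
    out_keys=["sample"],
    distribution_class=Normal,
    distribution_kwargs={"scale": 1},
    num_samples=2,  # new argument: prepend a sample dimension of size 2
    default_interaction_type="random",
)
td = prob(TensorDict(loc=torch.randn(3, 4), batch_size=[3]))
assert td.shape == (2, 3)  # batch size expanded by num_samples
assert td["sample"].shape == (2, 3, 4)

# In a sequence, the accumulated sample dims are exposed through the new
# `num_samples` property added to ProbabilisticTensorDictSequential.
seq = ProbabilisticTensorDictSequential(
    TensorDictModule(lambda x: x + 1, in_keys=["x"], out_keys=["loc"]),
    prob,
)
assert seq.num_samples == (2,)
```

The expansion happens inside `forward` (and `get_dist` expands `td_copy` for intermediate probabilistic modules), so modules placed after a sampling step in a `ProbabilisticTensorDictSequential` see the extra leading dimension.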