Skip to content

Commit

Permalink
Update
Browse files Browse the repository at this point in the history
[ghstack-poisoned]
  • Loading branch information
vmoens committed Nov 29, 2024
1 parent 7016a89 commit b0332b0
Show file tree
Hide file tree
Showing 2 changed files with 234 additions and 82 deletions.
33 changes: 30 additions & 3 deletions tensordict/nn/probabilistic.py
Original file line number Diff line number Diff line change
Expand Up @@ -295,6 +295,7 @@ def __init__(
log_prob_key: Optional[NestedKey] = "sample_log_prob",
cache_dist: bool = False,
n_empirical_estimate: int = 1000,
num_samples: int | torch.Size | None = None,
) -> None:
super().__init__()
distribution_kwargs = (
Expand Down Expand Up @@ -346,6 +347,9 @@ def __init__(
self._dist = None
self.cache_dist = cache_dist if hasattr(distribution_class, "update") else False
self.return_log_prob = return_log_prob
if isinstance(num_samples, (int, torch.SymInt)):
num_samples = torch.Size((num_samples,))
self.num_samples = num_samples
if self.return_log_prob and self.log_prob_key not in self.out_keys:
self.out_keys.append(self.log_prob_key)

Expand Down Expand Up @@ -464,6 +468,11 @@ def forward(
dist = self.get_dist(tensordict)
if _requires_sample:
out_tensors = self._dist_sample(dist, interaction_type=interaction_type())
if self.num_samples is not None:
# TODO: capture contiguous error here
tensordict_out = tensordict_out.expand(
self.num_samples + tensordict_out.shape
)
if isinstance(out_tensors, TensorDictBase):
if self.return_log_prob:
kwargs = {}
Expand Down Expand Up @@ -571,10 +580,13 @@ def _dist_sample(
return dist.sample((self.n_empirical_estimate,)).mean(0)

elif interaction_type is InteractionType.RANDOM:
num_samples = self.num_samples
if num_samples is None:
num_samples = torch.Size(())
if dist.has_rsample:
return dist.rsample()
return dist.rsample(num_samples)
else:
return dist.sample()
return dist.sample(num_samples)
else:
raise NotImplementedError(f"unknown interaction_type {interaction_type}")

Expand Down Expand Up @@ -858,6 +870,16 @@ def get_dist_params(
with set_interaction_type(type):
return tds(tensordict, tensordict_out, **kwargs)

@property
def num_samples(self):
num_samples = ()
for tdm in self.module:
if isinstance(
tdm, (ProbabilisticTensorDictModule, ProbabilisticTensorDictSequential)
):
num_samples = tdm.num_samples + num_samples
return num_samples

def get_dist(
self,
tensordict: TensorDictBase,
Expand Down Expand Up @@ -902,6 +924,8 @@ def get_dist(
dist = tdm.get_dist(td_copy)
if i < len(self.module) - 1:
sample = tdm._dist_sample(dist, interaction_type=interaction_type())
if tdm.num_samples not in ((), None):
td_copy = td_copy.expand(tdm.num_samples + td_copy.shape)
if isinstance(tdm, ProbabilisticTensorDictModule):
if isinstance(sample, torch.Tensor):
sample = [sample]
Expand Down Expand Up @@ -1119,7 +1143,10 @@ def forward(
else:
result = tensordict_exec
if self._select_before_return:
return tensordict.update(result, keys_to_update=self.out_keys)
# We must also update any value that has been updated during the course of execution
# from the input data.
keys = list(set(self.out_keys + list(tensordict.keys(True, True))))
return tensordict.update(result, keys_to_update=keys)
return result


Expand Down
283 changes: 204 additions & 79 deletions test/test_nn.py
Original file line number Diff line number Diff line change
Expand Up @@ -1866,100 +1866,225 @@ def test_module_buffer():
assert module.td.device.type == "cuda"


@pytest.mark.parametrize(
"log_prob_key",
[
None,
"sample_log_prob",
("nested", "sample_log_prob"),
("data", "sample_log_prob"),
],
)
def test_nested_keys_probabilistic_delta(log_prob_key):
policy_module = TensorDictModule(
nn.Linear(1, 1), in_keys=[("data", "states")], out_keys=[("data", "param")]
)
td = TensorDict({"data": TensorDict({"states": torch.zeros(3, 4, 1)}, [3, 4])}, [3])

module = ProbabilisticTensorDictModule(
in_keys=[("data", "param")],
out_keys=[("data", "action")],
distribution_class=Delta,
return_log_prob=True,
log_prob_key=log_prob_key,
)
td_out = module(policy_module(td))
assert td_out["data", "action"].shape == (3, 4, 1)
if log_prob_key:
assert td_out[log_prob_key].shape == (3, 4)
else:
assert td_out["sample_log_prob"].shape == (3, 4)

module = ProbabilisticTensorDictModule(
in_keys={"param": ("data", "param")},
out_keys=[("data", "action")],
distribution_class=Delta,
return_log_prob=True,
log_prob_key=log_prob_key,
)
td_out = module(policy_module(td))
assert td_out["data", "action"].shape == (3, 4, 1)
if log_prob_key:
assert td_out[log_prob_key].shape == (3, 4)
else:
assert td_out["sample_log_prob"].shape == (3, 4)
class TestProbabilisticTensorDictModule:
@pytest.mark.parametrize("return_log_prob", [True, False])
def test_probabilistic_n_samples(self, return_log_prob):
prob = ProbabilisticTensorDictModule(
in_keys=["loc"],
out_keys=["sample"],
distribution_class=Normal,
distribution_kwargs={"scale": 1},
return_log_prob=return_log_prob,
num_samples=2,
default_interaction_type="random",
)
# alone
td = TensorDict(loc=torch.randn(3, 4), batch_size=[3])
td = prob(td)
assert "sample" in td
assert td.shape == (2, 3)
assert td["sample"].shape == (2, 3, 4)
if return_log_prob:
assert "sample_log_prob" in td

@pytest.mark.parametrize("return_log_prob", [True, False])
@pytest.mark.parametrize("inplace", [True, False])
@pytest.mark.parametrize("include_sum", [True, False])
@pytest.mark.parametrize("aggregate_probabilities", [True, False])
@pytest.mark.parametrize("return_composite", [True, False])
def test_probabilistic_seq_n_samples(
self,
return_log_prob,
return_composite,
include_sum,
aggregate_probabilities,
inplace,
):
prob = ProbabilisticTensorDictModule(
in_keys=["loc"],
out_keys=["sample"],
distribution_class=Normal,
distribution_kwargs={"scale": 1},
return_log_prob=return_log_prob,
num_samples=2,
default_interaction_type="random",
)
# in a sequence
seq = ProbabilisticTensorDictSequential(
TensorDictModule(lambda x: x + 1, in_keys=["x"], out_keys=["loc"]),
prob,
inplace=inplace,
aggregate_probabilities=aggregate_probabilities,
include_sum=include_sum,
return_composite=return_composite,
)
td = TensorDict(x=torch.randn(3, 4), batch_size=[3])
if return_composite:
assert isinstance(seq.get_dist(td), CompositeDistribution)
else:
assert isinstance(seq.get_dist(td), Normal)
td = seq(td)
assert "sample" in td
assert td.shape == (2, 3)
assert td["sample"].shape == (2, 3, 4)
if return_log_prob:
assert "sample_log_prob" in td

# log-prob from the sequence
log_prob = seq.log_prob(td)
if aggregate_probabilities or not return_composite:
assert isinstance(log_prob, torch.Tensor)
else:
assert isinstance(log_prob, TensorDict)

@pytest.mark.parametrize(
"log_prob_key",
[
None,
"sample_log_prob",
("nested", "sample_log_prob"),
("data", "sample_log_prob"),
],
)
def test_nested_keys_probabilistic_normal(log_prob_key):
loc_module = TensorDictModule(
nn.Linear(1, 1),
in_keys=[("data", "states")],
out_keys=[("data", "loc")],
)
scale_module = TensorDictModule(
nn.Linear(1, 1),
in_keys=[("data", "states")],
out_keys=[("data", "scale")],
)
td = TensorDict({"data": TensorDict({"states": torch.zeros(3, 4, 1)}, [3, 4])}, [3])

module = ProbabilisticTensorDictModule(
in_keys=[("data", "loc"), ("data", "scale")],
out_keys=[("data", "action")],
distribution_class=Normal,
return_log_prob=True,
log_prob_key=log_prob_key,
@pytest.mark.parametrize("return_log_prob", [True, False])
@pytest.mark.parametrize("inplace", [True, False])
@pytest.mark.parametrize("include_sum", [True, False])
@pytest.mark.parametrize("aggregate_probabilities", [True, False])
@pytest.mark.parametrize("return_composite", [True])
def test_intermediate_probabilistic_seq_n_samples(
self,
return_log_prob,
return_composite,
include_sum,
aggregate_probabilities,
inplace,
):
prob = ProbabilisticTensorDictModule(
in_keys=["loc"],
out_keys=["sample"],
distribution_class=Normal,
distribution_kwargs={"scale": 1},
return_log_prob=return_log_prob,
num_samples=2,
default_interaction_type="random",
)

# intermediate in a sequence
seq = ProbabilisticTensorDictSequential(
TensorDictModule(lambda x: x + 1, in_keys=["x"], out_keys=["loc"]),
prob,
TensorDictModule(
lambda x: x + 1, in_keys=["sample"], out_keys=["new_sample"]
),
inplace=inplace,
aggregate_probabilities=aggregate_probabilities,
include_sum=include_sum,
return_composite=return_composite,
)
td = TensorDict(x=torch.randn(3, 4), batch_size=[3])
assert isinstance(seq.get_dist(td), CompositeDistribution)
td = seq(td)
assert "sample" in td
assert td.shape == (2, 3)
assert td["sample"].shape == (2, 3, 4)
if return_log_prob:
assert "sample_log_prob" in td

# log-prob from the sequence
log_prob = seq.log_prob(td)
if aggregate_probabilities or not return_composite:
assert isinstance(log_prob, torch.Tensor)
else:
assert isinstance(log_prob, TensorDict)

@pytest.mark.parametrize(
"log_prob_key",
[
None,
"sample_log_prob",
("nested", "sample_log_prob"),
("data", "sample_log_prob"),
],
)
with pytest.warns(UserWarning, match="deterministic_sample"):
td_out = module(loc_module(scale_module(td)))
def test_nested_keys_probabilistic_delta(self, log_prob_key):
policy_module = TensorDictModule(
nn.Linear(1, 1), in_keys=[("data", "states")], out_keys=[("data", "param")]
)
td = TensorDict(
{"data": TensorDict({"states": torch.zeros(3, 4, 1)}, [3, 4])}, [3]
)

module = ProbabilisticTensorDictModule(
in_keys=[("data", "param")],
out_keys=[("data", "action")],
distribution_class=Delta,
return_log_prob=True,
log_prob_key=log_prob_key,
)
td_out = module(policy_module(td))
assert td_out["data", "action"].shape == (3, 4, 1)
if log_prob_key:
assert td_out[log_prob_key].shape == (3, 4, 1)
assert td_out[log_prob_key].shape == (3, 4)
else:
assert td_out["sample_log_prob"].shape == (3, 4, 1)
assert td_out["sample_log_prob"].shape == (3, 4)

module = ProbabilisticTensorDictModule(
in_keys={"loc": ("data", "loc"), "scale": ("data", "scale")},
in_keys={"param": ("data", "param")},
out_keys=[("data", "action")],
distribution_class=Normal,
distribution_class=Delta,
return_log_prob=True,
log_prob_key=log_prob_key,
)
td_out = module(loc_module(scale_module(td)))
td_out = module(policy_module(td))
assert td_out["data", "action"].shape == (3, 4, 1)
if log_prob_key:
assert td_out[log_prob_key].shape == (3, 4, 1)
assert td_out[log_prob_key].shape == (3, 4)
else:
assert td_out["sample_log_prob"].shape == (3, 4, 1)
assert td_out["sample_log_prob"].shape == (3, 4)

@pytest.mark.parametrize(
"log_prob_key",
[
None,
"sample_log_prob",
("nested", "sample_log_prob"),
("data", "sample_log_prob"),
],
)
def test_nested_keys_probabilistic_normal(self, log_prob_key):
loc_module = TensorDictModule(
nn.Linear(1, 1),
in_keys=[("data", "states")],
out_keys=[("data", "loc")],
)
scale_module = TensorDictModule(
nn.Linear(1, 1),
in_keys=[("data", "states")],
out_keys=[("data", "scale")],
)
td = TensorDict(
{"data": TensorDict({"states": torch.zeros(3, 4, 1)}, [3, 4])}, [3]
)

module = ProbabilisticTensorDictModule(
in_keys=[("data", "loc"), ("data", "scale")],
out_keys=[("data", "action")],
distribution_class=Normal,
return_log_prob=True,
log_prob_key=log_prob_key,
)
with pytest.warns(UserWarning, match="deterministic_sample"):
td_out = module(loc_module(scale_module(td)))
assert td_out["data", "action"].shape == (3, 4, 1)
if log_prob_key:
assert td_out[log_prob_key].shape == (3, 4, 1)
else:
assert td_out["sample_log_prob"].shape == (3, 4, 1)

module = ProbabilisticTensorDictModule(
in_keys={"loc": ("data", "loc"), "scale": ("data", "scale")},
out_keys=[("data", "action")],
distribution_class=Normal,
return_log_prob=True,
log_prob_key=log_prob_key,
)
td_out = module(loc_module(scale_module(td)))
assert td_out["data", "action"].shape == (3, 4, 1)
if log_prob_key:
assert td_out[log_prob_key].shape == (3, 4, 1)
else:
assert td_out["sample_log_prob"].shape == (3, 4, 1)


class TestEnsembleModule:
Expand Down

0 comments on commit b0332b0

Please sign in to comment.