Skip to content

Commit

Permalink
Update
Browse files Browse the repository at this point in the history
[ghstack-poisoned]
  • Loading branch information
vmoens committed Sep 13, 2024
2 parents f9aa095 + 515b61c commit 98e3914
Show file tree
Hide file tree
Showing 6 changed files with 129 additions and 9 deletions.
68 changes: 67 additions & 1 deletion test/test_cost.py
Original file line number Diff line number Diff line change
Expand Up @@ -7565,6 +7565,7 @@ def _create_mock_actor(
"action1": (action_key, "action1"),
},
log_prob_key=sample_log_prob_key,
aggregate_probabilities=True,
)
module_out_keys = [
("params", "action1", "loc"),
Expand Down Expand Up @@ -7634,6 +7635,7 @@ def _create_mock_actor_value(
"action1": ("action", "action1"),
},
log_prob_key=sample_log_prob_key,
aggregate_probabilities=True,
)
module_out_keys = [
("params", "action1", "loc"),
Expand Down Expand Up @@ -7690,6 +7692,7 @@ def _create_mock_actor_value_shared(
"action1": ("action", "action1"),
},
log_prob_key=sample_log_prob_key,
aggregate_probabilities=True,
)
module_out_keys = [
("params", "action1", "loc"),
Expand Down Expand Up @@ -8627,6 +8630,7 @@ def _create_mock_actor(
"action1": (action_key, "action1"),
},
log_prob_key=sample_log_prob_key,
aggregate_probabilities=True,
)
module_out_keys = [
("params", "action1", "loc"),
Expand Down Expand Up @@ -8727,6 +8731,7 @@ def _create_mock_common_layer_setup(
"action1": ("action", "action1"),
},
log_prob_key=sample_log_prob_key,
aggregate_probabilities=True,
)
module_out_keys = [
("params", "action1", "loc"),
Expand Down Expand Up @@ -15277,7 +15282,7 @@ def _forward_value_estimator_keys(self, **kwargs) -> None:
class MyLoss3(MyLoss2):
@dataclass
class _AcceptedKeys:
some_key = "some_value"
some_key: str = "some_value"

loss_module = MyLoss3()
assert loss_module.tensor_keys.some_key == "some_value"
Expand Down Expand Up @@ -15639,6 +15644,67 @@ def __init__(self):
assert p.device == dest


def test_exploration_compile():
m = ProbabilisticTensorDictModule(
in_keys=["loc", "scale"],
out_keys=["sample"],
distribution_class=torch.distributions.Normal,
)

# class set_exploration_type_random(set_exploration_type):
# __init__ = object.__init__
# type = ExplorationType.RANDOM
it = exploration_type()

@torch.compile(fullgraph=True)
def func(t):
with set_exploration_type(ExplorationType.RANDOM):
t0 = m(t.clone())
t1 = m(t.clone())
return t0, t1

t = TensorDict(loc=torch.randn(3), scale=torch.rand(3))
t0, t1 = func(t)
assert (t0["sample"] != t1["sample"]).any()
assert it == exploration_type()

@torch.compile(fullgraph=True)
def func(t):
with set_exploration_type(ExplorationType.MEAN):
t0 = m(t.clone())
t1 = m(t.clone())
return t0, t1

t = TensorDict(loc=torch.randn(3), scale=torch.rand(3))
t0, t1 = func(t)
assert (t0["sample"] == t1["sample"]).all()
assert it == exploration_type()

@torch.compile(fullgraph=True)
@set_exploration_type(ExplorationType.RANDOM)
def func(t):
t0 = m(t.clone())
t1 = m(t.clone())
return t0, t1

t = TensorDict(loc=torch.randn(3), scale=torch.rand(3))
t0, t1 = func(t)
assert (t0["sample"] != t1["sample"]).any()
assert it == exploration_type()

@torch.compile(fullgraph=True)
@set_exploration_type(ExplorationType.MEAN)
def func(t):
t0 = m(t.clone())
t1 = m(t.clone())
return t0, t1

t = TensorDict(loc=torch.randn(3), scale=torch.rand(3))
t0, t1 = func(t)
assert (t0["sample"] == t1["sample"]).all()
assert it == exploration_type()


def test_loss_exploration():
class DummyLoss(LossModule):
def forward(self, td, mode):
Expand Down
41 changes: 41 additions & 0 deletions torchrl/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,15 @@
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import os
import weakref
from warnings import warn

import torch

from tensordict import set_lazy_legacy

from torch import multiprocessing as mp
from torch.distributions.transforms import _InverseTransform, ComposeTransform

set_lazy_legacy(False).set()

Expand Down Expand Up @@ -51,3 +53,42 @@
filter_warnings_subprocess = True

_THREAD_POOL_INIT = torch.get_num_threads()

# monkey-patch dist transforms until https://github.com/pytorch/pytorch/pull/135001/ finds a home
@property
def inv(self):
"""
Returns the inverse :class:`Transform` of this transform.
This should satisfy ``t.inv.inv is t``.
"""
inv = None
if self._inv is not None:
inv = self._inv()
if inv is None:
inv = _InverseTransform(self)
if not torch.compiler.is_dynamo_compiling():
self._inv = weakref.ref(inv)
return inv


torch.distributions.transforms.Transform.inv = inv


@property
def inv(self):
inv = None
if self._inv is not None:
inv = self._inv()
if inv is None:
inv = ComposeTransform([p.inv for p in reversed(self.parts)])
if not torch.compiler.is_dynamo_compiling():
self._inv = weakref.ref(inv)
inv._inv = weakref.ref(self)
else:
# We need inv.inv to be equal to self, but weakref can cause a graph break
inv._inv = lambda out=self: out

return inv


ComposeTransform.inv = inv
9 changes: 7 additions & 2 deletions torchrl/objectives/a2c.py
Original file line number Diff line number Diff line change
Expand Up @@ -427,8 +427,13 @@ def _log_probs(
if isinstance(action, torch.Tensor):
log_prob = dist.log_prob(action)
else:
tensordict = dist.log_prob(tensordict)
log_prob = tensordict.get(self.tensor_keys.sample_log_prob)
maybe_log_prob = dist.log_prob(tensordict)
if not isinstance(maybe_log_prob, torch.Tensor):
# In some cases (Composite distribution with aggregate_probabilities toggled off) the returned type may not
# be a tensor
log_prob = maybe_log_prob.get(self.tensor_keys.sample_log_prob)
else:
log_prob = maybe_log_prob
log_prob = log_prob.unsqueeze(-1)
return log_prob, dist

Expand Down
7 changes: 4 additions & 3 deletions torchrl/objectives/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,6 @@ def _forward_wrapper(func):
@functools.wraps(func)
def new_forward(self, *args, **kwargs):
with set_exploration_type(self.deterministic_sampling_mode):
# with nullcontext():
return func(self, *args, **kwargs)

return new_forward
Expand All @@ -55,7 +54,7 @@ def new_forward(self, *args, **kwargs):
class _LossMeta(abc.ABCMeta):
def __init__(cls, name, bases, attr_dict):
super().__init__(name, bases, attr_dict)
# cls.forward = _forward_wrapper(cls.forward)
cls.forward = _forward_wrapper(cls.forward)


class LossModule(TensorDictModuleBase, metaclass=_LossMeta):
Expand Down Expand Up @@ -229,7 +228,9 @@ def set_keys(self, **kwargs) -> None:
"""
for key, value in kwargs.items():
if key not in self._AcceptedKeys.__dataclass_fields__:
raise ValueError(f"{key} is not an accepted tensordict key")
raise ValueError(
f"{key} is not an accepted tensordict key. Accepted keys are: {self._AcceptedKeys.__dataclass_fields__}."
)
if value is not None:
setattr(self.tensor_keys, key, value)
else:
Expand Down
11 changes: 8 additions & 3 deletions torchrl/objectives/ppo.py
Original file line number Diff line number Diff line change
Expand Up @@ -495,8 +495,13 @@ def _log_weight(
if isinstance(action, torch.Tensor):
log_prob = dist.log_prob(action)
else:
tensordict = dist.log_prob(tensordict)
log_prob = tensordict.get(self.tensor_keys.sample_log_prob)
maybe_log_prob = dist.log_prob(tensordict)
if not isinstance(maybe_log_prob, torch.Tensor):
# In some cases (Composite distribution with aggregate_probabilities toggled off) the returned type may not
# be a tensor
log_prob = maybe_log_prob.get(self.tensor_keys.sample_log_prob)
else:
log_prob = maybe_log_prob

log_weight = (log_prob - prev_log_prob).unsqueeze(-1)
kl_approx = (prev_log_prob - log_prob).unsqueeze(-1)
Expand Down Expand Up @@ -1144,7 +1149,7 @@ def forward(self, tensordict: TensorDictBase) -> TensorDict:
x = previous_dist.sample((self.samples_mc_kl,))
previous_log_prob = previous_dist.log_prob(x)
current_log_prob = current_dist.log_prob(x)
if is_tensor_collection(x):
if is_tensor_collection(current_log_prob):
previous_log_prob = previous_log_prob.get(
self.tensor_keys.sample_log_prob
)
Expand Down
2 changes: 2 additions & 0 deletions torchrl/objectives/redq.py
Original file line number Diff line number Diff line change
Expand Up @@ -453,6 +453,8 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
tensordict_select = tensordict.select(
"next", *obs_keys, self.tensor_keys.action, strict=False
)
# We need to copy bc select does not copy sub-tds
tensordict_select = tensordict_select.copy()

selected_models_idx = torch.randperm(self.num_qvalue_nets)[
: self.sub_sample_len
Expand Down

0 comments on commit 98e3914

Please sign in to comment.