
Commit

Update
[ghstack-poisoned]

vmoens committed Jan 10, 2025
2 parents c7c9021 + d6ca42f commit 5c03f9f
Showing 11 changed files with 39 additions and 30 deletions.
19 changes: 11 additions & 8 deletions docs/source/reference/envs.rst
@@ -218,7 +218,7 @@ The ``"_reset"`` key has two distinct functionalities:
modification will be lost. After this masking operation, the ``"_reset"``
entries will be erased from the :meth:`~.EnvBase.reset` outputs.

-It must be pointed that ``"_reset"`` is a private key, and it should only be
+It must be pointed out that ``"_reset"`` is a private key, and it should only be
used when coding specific environment features that are internal facing.
In other words, this should NOT be used outside of the library, and developers
will keep the right to modify the logic of partial resets through ``"_reset"``
@@ -243,7 +243,7 @@ designing reset functionalities:
``any`` or ``all`` logic depending on the task).
- When calling :meth:`env.reset(tensordict)` with a partial ``"_reset"`` entry
that will reset some but not all the done sub-environments, the input data
-  should contain the data of the sub-environemtns that are __not__ being reset.
+  should contain the data of the sub-environments that are __not__ being reset.
The reason for this constrain lies in the fact that the output of the
``env._reset(data)`` can only be predicted for the entries that are reset.
For the others, TorchRL cannot know in advance if they will be meaningful or
@@ -267,7 +267,7 @@ have on an environment returning zeros after reset:
>>> env.reset(data)
>>> print(data.get(("agent0", "val"))) # only the second value is 0
tensor([1, 0])
>>> print(data.get(("agent1", "val"))) # only the second value is 0
>>> print(data.get(("agent1", "val"))) # only the first value is 0
tensor([0, 2])
>>> # nested resets are overridden by a "_reset" at the root
>>> data = TensorDict({
@@ -573,7 +573,7 @@ Dynamic Specs
.. _dynamic_envs:

Running environments in parallel is usually done via the creation of memory buffers used to pass information from one
-process to another. In some cases, it may be impossible to forecast whether and environment will or will not have
+process to another. In some cases, it may be impossible to forecast whether an environment will or will not have
consistent inputs or outputs during a rollout, as their shape may be variable. We refer to this as dynamic specs.

TorchRL is capable of handling dynamic specs, but the batched environments and collectors will need to be made
@@ -670,9 +670,12 @@ Here is a working example:
is_shared=False,
stack_dim=0)

-.. warning:: The absence of memory buffers in :class:`~torchrl.envs.ParallelEnv` and in data collectors can impact
-performance of these classes dramatically. Any such usage should be carefully benchmarked against a plain execution on
-a single process, as serializing and deserializing large numbers of tensors can be very expensive.
+.. warning::
+The absence of memory buffers in :class:`~torchrl.envs.ParallelEnv` and in
+data collectors can impact performance of these classes dramatically. Any
+such usage should be carefully benchmarked against a plain execution on a
+single process, as serializing and deserializing large numbers of tensors
+can be very expensive.
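
As a rough illustration of the benchmarking advice in this warning, here is a minimal sketch (it assumes gym is installed; the environment name, worker count and rollout length are arbitrary placeholders):

>>> import time
>>> from torchrl.envs import GymEnv, ParallelEnv, SerialEnv
>>> def timed_rollout(env, steps=500):
...     env.rollout(steps, break_when_any_done=False)  # warm-up pass
...     start = time.perf_counter()
...     env.rollout(steps, break_when_any_done=False)
...     return time.perf_counter() - start
>>> serial = SerialEnv(4, lambda: GymEnv("Pendulum-v1"))
>>> parallel = ParallelEnv(4, lambda: GymEnv("Pendulum-v1"))
>>> print(timed_rollout(serial), timed_rollout(parallel))
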

Currently, :func:`~torchrl.envs.utils.check_env_specs` will pass for dynamic specs where a shape varies along some
dimensions, but not when a key is present during a step and absent during others, or when the number of dimensions
@@ -941,7 +944,7 @@ formatted images (WHC or CWH).
>>> env.transform.dump() # Save the video and clear cache

Note that the cache of the transform will keep on growing until dump is called. It is the user responsibility to
-take care of calling dumpy when needed to avoid OOM issues.
+take care of calling `dump` when needed to avoid OOM issues.

In some cases, creating a testing environment where images can be collected is tedious or expensive, or simply impossible
(some libraries only allow one environment instance per workspace).
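To make the dump-to-avoid-OOM advice concrete, here is a small sketch (the logger name and environment are hypothetical; it assumes a Gym backend that can render pixels and torchvision for mp4 encoding):

>>> from torchrl.envs import GymEnv, TransformedEnv
>>> from torchrl.record import CSVLogger, VideoRecorder
>>> logger = CSVLogger("video_example", video_format="mp4")
>>> env = TransformedEnv(
...     GymEnv("CartPole-v1", from_pixels=True, pixels_only=False),
...     VideoRecorder(logger, tag="rollout"),
... )
>>> for _ in range(10):
...     env.rollout(50)
...     env.transform.dump()  # write the cached frames to disk and clear the cache

Flushing at a fixed cadence keeps the frame cache bounded instead of letting it grow for the whole run.
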
4 changes: 2 additions & 2 deletions torchrl/envs/transforms/transforms.py
@@ -3533,7 +3533,7 @@ class DTypeCastTransform(Transform):
>>> print(td.get("not_transformed").dtype)
torch.float32
-The same behavior is the rule when environments are constructedw without
+The same behavior is the rule when environments are constructed without
specifying the transform keys:
Examples:
@@ -3903,7 +3903,7 @@ class DoubleToFloat(DTypeCastTransform):
>>> print(td.get("not_transformed").dtype)
torch.float32
-The same behavior is the rule when environments are constructedw without
+The same behavior is the rule when environments are constructed without
specifying the transform keys:
Examples:
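As a sketch of the keyless usage these docstrings refer to (assuming dm_control is installed, whose observations are float64 by default; without ``in_keys`` the transform infers the float64 entries from the env specs):

>>> import torch
>>> from torchrl.envs import DMControlEnv, DoubleToFloat, TransformedEnv
>>> env = TransformedEnv(DMControlEnv("cheetah", "run"), DoubleToFloat())
>>> td = env.reset()
>>> print({key: value.dtype for key, value in td.items()})  # float64 leaves now read as torch.float32
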
4 changes: 2 additions & 2 deletions torchrl/modules/tensordict_module/probabilistic.py
@@ -213,8 +213,8 @@ class SafeProbabilisticTensorDictSequential(
instances, terminating in ProbabilisticTensorDictModule, to be run
sequentially.
partial_tolerant (bool, optional): if ``True``, the input tensordict can miss some
-of the input keys. If so, the only module that will be executed are those
-who can be executed given the keys that are present. Also, if the input
+of the input keys. If so, the only modules that will be executed are those
+which can be executed given the keys that are present. Also, if the input
tensordict is a lazy stack of tensordicts AND if partial_tolerant is
``True`` AND if the stack does not have the required keys, then
TensorDictSequential will scan through the sub-tensordicts looking for those
2 changes: 1 addition & 1 deletion torchrl/modules/tensordict_module/sequence.py
@@ -23,7 +23,7 @@ class SafeSequential(TensorDictSequential, SafeModule):
Args:
modules (iterable of TensorDictModules): ordered sequence of TensorDictModule instances to be run sequentially.
partial_tolerant (bool, optional): if ``True``, the input tensordict can miss some of the input keys.
-If so, the only module that will be executed are those who can be executed given the keys that
+If so, the only modules that will be executed are those which can be executed given the keys that
are present.
Also, if the input tensordict is a lazy stack of tensordicts AND if partial_tolerant is ``True`` AND if the
stack does not have the required keys, then SafeSequential will scan through the sub-tensordicts
2 changes: 1 addition & 1 deletion torchrl/objectives/cql.py
@@ -892,7 +892,7 @@ def alpha_loss(self, tensordict: TensorDictBase) -> Tensor:

@property
def _alpha(self):
-if self.min_log_alpha is not None:
+if self.min_log_alpha is not None or self.max_log_alpha is not None:
self.log_alpha.data.clamp_(self.min_log_alpha, self.max_log_alpha)
alpha = self.log_alpha.data.exp()
return alpha
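For context on the widened condition above (the same change appears in the crossq, decision_transformer and sac losses below): ``torch.Tensor.clamp_`` accepts ``None`` for either bound, so a single call covers min-only, max-only and two-sided clamping, but it errors when both bounds are ``None`` — hence the guard now only skips clamping when neither bound is set. A quick sketch:

>>> import torch
>>> log_alpha = torch.zeros(())
>>> log_alpha.clamp_(None, 0.0)   # upper bound only
tensor(0.)
>>> log_alpha.clamp_(-10.0, None)  # lower bound only
tensor(0.)
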
2 changes: 1 addition & 1 deletion torchrl/objectives/crossq.py
@@ -677,7 +677,7 @@ def alpha_loss(self, log_prob: Tensor) -> Tensor:

@property
def _alpha(self):
-if self.min_log_alpha is not None:
+if self.min_log_alpha is not None or self.max_log_alpha is not None:
self.log_alpha.data.clamp_(self.min_log_alpha, self.max_log_alpha)
with torch.no_grad():
alpha = self.log_alpha.exp()
2 changes: 1 addition & 1 deletion torchrl/objectives/decision_transformer.py
@@ -171,7 +171,7 @@ def _forward_value_estimator_keys(self, **kwargs):

@property
def alpha(self):
-if self.min_log_alpha is not None:
+if self.min_log_alpha is not None or self.max_log_alpha is not None:
self.log_alpha.data.clamp_(self.min_log_alpha, self.max_log_alpha)
with torch.no_grad():
alpha = self.log_alpha.exp()
21 changes: 14 additions & 7 deletions torchrl/objectives/ppo.py
@@ -69,7 +69,7 @@ class PPOLoss(LossModule):
Args:
actor_network (ProbabilisticTensorDictSequential): policy operator.
-Typically a :class:`~tensordict.nn.ProbabilisticTensorDictSequential` subclass taking observations
+Typically, a :class:`~tensordict.nn.ProbabilisticTensorDictSequential` subclass taking observations
as input and outputting an action (or actions) as well as its log-probability value.
critic_network (ValueOperator): value operator. The critic will usually take the observations as input
and return a scalar value (``state_value`` by default) in the output keys.
@@ -490,7 +490,10 @@ def get_entropy_bonus(self, dist: d.Distribution) -> torch.Tensor:

if is_tensor_collection(log_prob):
if isinstance(self.tensor_keys.sample_log_prob, NestedKey):
-log_prob = log_prob.get(self.tensor_keys.sample_log_prob)
+try:
+log_prob = log_prob.get(self.tensor_keys.sample_log_prob)
+except KeyError as err:
+raise _make_lp_get_error(self.tensor_keys, log_prob, err)
else:
log_prob = log_prob.select(*self.tensor_keys.sample_log_prob)

@@ -511,9 +514,12 @@ def _log_weight(
) if self.functional else contextlib.nullcontext():
dist = self.actor_network.get_dist(tensordict)

-prev_log_prob = _maybe_get_or_select(
-tensordict, self.tensor_keys.sample_log_prob
-)
+try:
+prev_log_prob = _maybe_get_or_select(
+tensordict, self.tensor_keys.sample_log_prob
+)
+except KeyError as err:
+raise _make_lp_get_error(self.tensor_keys, tensordict, err)

if prev_log_prob.requires_grad:
raise RuntimeError(
@@ -930,15 +936,15 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase:

gain = torch.stack([gain1, gain2], -1).min(dim=-1)[0]
if is_tensor_collection(gain):
-gain = gain.sum(reduce=True)
+gain = _sum_td_features(gain)
td_out = TensorDict({"loss_objective": -gain}, batch_size=[])
td_out.set("clip_fraction", clip_fraction)

if self.entropy_bonus:
entropy = self.get_entropy_bonus(dist)
td_out.set("entropy", entropy.detach().mean()) # for logging
td_out.set("kl_approx", kl_approx.detach().mean()) # for logging
td_out.set("loss_entropy", -self.entropy_coef * entropy.mean())
td_out.set("loss_entropy", -self.entropy_coef * entropy)
if self.critic_coef is not None:
loss_critic, value_clip_fraction = self.loss_critic(tensordict)
td_out.set("loss_critic", loss_critic)
@@ -1223,6 +1229,7 @@ def forward(self, tensordict: TensorDictBase) -> TensorDict:
current_log_prob = current_dist.log_prob(x, **kwargs)
if is_tensor_collection(previous_log_prob):
previous_log_prob = _sum_td_features(previous_log_prob)
+# Both dists have presumably the same params
current_log_prob = _sum_td_features(current_log_prob)
kl = (previous_log_prob - current_log_prob).mean(0)
kl = kl.unsqueeze(-1)
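The try/except changes above route missing log-prob keys through a dedicated error helper. A rough sketch of that pattern follows; ``_make_lp_get_error`` is private to torchrl, so the function below is a hypothetical stand-in used only to illustrate the idea of re-raising with a message that lists the available keys:

>>> import torch
>>> from tensordict import TensorDict
>>> def get_sample_log_prob(data, key):
...     try:
...         return data.get(key)
...     except KeyError as err:
...         raise KeyError(
...             f"couldn't find the log-prob entry {key!r}; available keys: "
...             f"{sorted(map(str, data.keys(True, True)))}"
...         ) from err
>>> td = TensorDict({"action": torch.randn(3)}, [])
>>> get_sample_log_prob(td, "sample_log_prob")  # raises a KeyError that names the missing entry
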
4 changes: 2 additions & 2 deletions torchrl/objectives/sac.py
@@ -846,7 +846,7 @@ def _alpha_loss(self, log_prob: Tensor) -> Tensor:

@property
def _alpha(self):
-if self.min_log_alpha is not None:
+if self.min_log_alpha is not None or self.max_log_alpha is not None:
self.log_alpha.data.clamp_(self.min_log_alpha, self.max_log_alpha)
with torch.no_grad():
alpha = self.log_alpha.exp()
@@ -1374,7 +1374,7 @@ def _alpha_loss(self, log_prob: Tensor) -> Tensor:

@property
def _alpha(self):
-if self.min_log_alpha is not None:
+if self.min_log_alpha is not None or self.max_log_alpha is not None:
self.log_alpha.data = self.log_alpha.data.clamp(
self.min_log_alpha, self.max_log_alpha
)
2 changes: 1 addition & 1 deletion tutorials/sphinx-tutorials/getting-started-0.py
@@ -106,7 +106,7 @@
print(reset_with_action["action"])

################################
-# We now need to pass this action tp the environment.
+# We now need to pass this action to the environment.
# We'll be passing the entire tensordict to the ``step`` method, since there
# might be more than one tensor to be read in more advanced cases like
# Multi-Agent RL or stateless environments:
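Following the corrected sentence above, the next step of the tutorial might look like this (a sketch reusing the names from the snippet; the step results are written under the ``"next"`` entry of the returned tensordict):

>>> stepped_data = env.step(reset_with_action)
>>> print(stepped_data["next", "reward"])
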
7 changes: 3 additions & 4 deletions tutorials/sphinx-tutorials/multi_task.py
@@ -11,8 +11,6 @@
# sphinx_gallery_start_ignore
import warnings

-from tensordict import LazyStackedTensorDict
-
warnings.filterwarnings("ignore")

from torch import multiprocessing
@@ -32,6 +30,7 @@

# sphinx_gallery_end_ignore

+from tensordict import LazyStackedTensorDict
from tensordict.nn import TensorDictModule, TensorDictSequential
from torch import nn

@@ -91,7 +90,7 @@
# ^^^^^^
#
# We will design a policy where a backbone reads the "observation" key.
-# Then specific sub-components will ready the "observation_stand" and
+# Then specific sub-components will read the "observation_stand" and
# "observation_walk" keys of the stacked tensordicts, if they are present,
# and pass them through the dedicated sub-network.

@@ -138,7 +137,7 @@
# Executing diverse tasks in parallel
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
#
-# We can parallelize the operations if the common keys-value pairs share the
+# We can parallelize the operations if the common key-value pairs share the
# same specs (in particular their shape and dtype must match: you can't do the
# following if the observation shapes are different but are pointed to by the
# same key).
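As a rough sketch of the backbone-plus-heads policy this tutorial describes (layer sizes are made up and only the ``"observation"``, ``"observation_stand"`` and ``"observation_walk"`` keys come from the tutorial; torchrl's ``MLP`` concatenates multiple inputs along the last dimension):

>>> import torch
>>> from tensordict import LazyStackedTensorDict, TensorDict
>>> from tensordict.nn import TensorDictModule, TensorDictSequential
>>> from torchrl.modules import MLP
>>> backbone = TensorDictModule(
...     MLP(in_features=8, out_features=64, num_cells=[64]),
...     in_keys=["observation"], out_keys=["hidden"],
... )
>>> head_stand = TensorDictModule(
...     MLP(in_features=64 + 4, out_features=6, num_cells=[64]),
...     in_keys=["hidden", "observation_stand"], out_keys=["action"],
... )
>>> head_walk = TensorDictModule(
...     MLP(in_features=64 + 4, out_features=6, num_cells=[64]),
...     in_keys=["hidden", "observation_walk"], out_keys=["action"],
... )
>>> policy = TensorDictSequential(backbone, head_stand, head_walk, partial_tolerant=True)
>>> td = LazyStackedTensorDict.lazy_stack([
...     TensorDict({"observation": torch.randn(8), "observation_stand": torch.randn(4)}, []),
...     TensorDict({"observation": torch.randn(8), "observation_walk": torch.randn(4)}, []),
... ])
>>> policy(td)  # each sub-tensordict goes through the backbone and its own task head

With ``partial_tolerant=True``, the head whose task-specific key is absent from a given sub-tensordict is simply skipped.
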
