diff --git a/docs/source/reference/envs.rst b/docs/source/reference/envs.rst index 9ef9d88dbe6..065d6a2e3d4 100644 --- a/docs/source/reference/envs.rst +++ b/docs/source/reference/envs.rst @@ -218,7 +218,7 @@ The ``"_reset"`` key has two distinct functionalities: modification will be lost. After this masking operation, the ``"_reset"`` entries will be erased from the :meth:`~.EnvBase.reset` outputs. -It must be pointed that ``"_reset"`` is a private key, and it should only be +It must be pointed out that ``"_reset"`` is a private key, and it should only be used when coding specific environment features that are internal facing. In other words, this should NOT be used outside of the library, and developers will keep the right to modify the logic of partial resets through ``"_reset"`` @@ -243,7 +243,7 @@ designing reset functionalities: ``any`` or ``all`` logic depending on the task). - When calling :meth:`env.reset(tensordict)` with a partial ``"_reset"`` entry that will reset some but not all the done sub-environments, the input data - should contain the data of the sub-environemtns that are __not__ being reset. + should contain the data of the sub-environments that are __not__ being reset. The reason for this constrain lies in the fact that the output of the ``env._reset(data)`` can only be predicted for the entries that are reset. For the others, TorchRL cannot know in advance if they will be meaningful or @@ -267,7 +267,7 @@ have on an environment returning zeros after reset: >>> env.reset(data) >>> print(data.get(("agent0", "val"))) # only the second value is 0 tensor([1, 0]) - >>> print(data.get(("agent1", "val"))) # only the second value is 0 + >>> print(data.get(("agent1", "val"))) # only the first value is 0 tensor([0, 2]) >>> # nested resets are overridden by a "_reset" at the root >>> data = TensorDict({ @@ -573,7 +573,7 @@ Dynamic Specs .. _dynamic_envs: Running environments in parallel is usually done via the creation of memory buffers used to pass information from one -process to another. In some cases, it may be impossible to forecast whether and environment will or will not have +process to another. In some cases, it may be impossible to forecast whether an environment will or will not have consistent inputs or outputs during a rollout, as their shape may be variable. We refer to this as dynamic specs. TorchRL is capable of handling dynamic specs, but the batched environments and collectors will need to be made @@ -670,9 +670,12 @@ Here is a working example: is_shared=False, stack_dim=0) -.. warning:: The absence of memory buffers in :class:`~torchrl.envs.ParallelEnv` and in data collectors can impact -performance of these classes dramatically. Any such usage should be carefully benchmarked against a plain execution on -a single process, as serializing and deserializing large numbers of tensors can be very expensive. +.. warning:: + The absence of memory buffers in :class:`~torchrl.envs.ParallelEnv` and in + data collectors can impact performance of these classes dramatically. Any + such usage should be carefully benchmarked against a plain execution on a + single process, as serializing and deserializing large numbers of tensors + can be very expensive. Currently, :func:`~torchrl.envs.utils.check_env_specs` will pass for dynamic specs where a shape varies along some dimensions, but not when a key is present during a step and absent during others, or when the number of dimensions @@ -941,7 +944,7 @@ formatted images (WHC or CWH). 
>>> env.transform.dump() # Save the video and clear cache Note that the cache of the transform will keep on growing until dump is called. It is the user responsibility to -take care of calling dumpy when needed to avoid OOM issues. +take care of calling `dump` when needed to avoid OOM issues. In some cases, creating a testing environment where images can be collected is tedious or expensive, or simply impossible (some libraries only allow one environment instance per workspace). diff --git a/torchrl/envs/transforms/transforms.py b/torchrl/envs/transforms/transforms.py index 64fad524d94..8e074fa8679 100644 --- a/torchrl/envs/transforms/transforms.py +++ b/torchrl/envs/transforms/transforms.py @@ -3533,7 +3533,7 @@ class DTypeCastTransform(Transform): >>> print(td.get("not_transformed").dtype) torch.float32 - The same behavior is the rule when environments are constructedw without + The same behavior is the rule when environments are constructed without specifying the transform keys: Examples: @@ -3903,7 +3903,7 @@ class DoubleToFloat(DTypeCastTransform): >>> print(td.get("not_transformed").dtype) torch.float32 - The same behavior is the rule when environments are constructedw without + The same behavior is the rule when environments are constructed without specifying the transform keys: Examples: diff --git a/torchrl/modules/tensordict_module/probabilistic.py b/torchrl/modules/tensordict_module/probabilistic.py index 5ea006b8d2f..4eb6e702c31 100644 --- a/torchrl/modules/tensordict_module/probabilistic.py +++ b/torchrl/modules/tensordict_module/probabilistic.py @@ -213,8 +213,8 @@ class SafeProbabilisticTensorDictSequential( instances, terminating in ProbabilisticTensorDictModule, to be run sequentially. partial_tolerant (bool, optional): if ``True``, the input tensordict can miss some - of the input keys. If so, the only module that will be executed are those - who can be executed given the keys that are present. Also, if the input + of the input keys. If so, the only modules that will be executed are those + which can be executed given the keys that are present. Also, if the input tensordict is a lazy stack of tensordicts AND if partial_tolerant is ``True`` AND if the stack does not have the required keys, then TensorDictSequential will scan through the sub-tensordicts looking for those diff --git a/torchrl/modules/tensordict_module/sequence.py b/torchrl/modules/tensordict_module/sequence.py index 938843e624f..44209c8335b 100644 --- a/torchrl/modules/tensordict_module/sequence.py +++ b/torchrl/modules/tensordict_module/sequence.py @@ -23,7 +23,7 @@ class SafeSequential(TensorDictSequential, SafeModule): Args: modules (iterable of TensorDictModules): ordered sequence of TensorDictModule instances to be run sequentially. partial_tolerant (bool, optional): if ``True``, the input tensordict can miss some of the input keys. - If so, the only module that will be executed are those who can be executed given the keys that + If so, the only modules that will be executed are those which can be executed given the keys that are present. 
Also, if the input tensordict is a lazy stack of tensordicts AND if partial_tolerant is ``True`` AND if the stack does not have the required keys, then SafeSequential will scan through the sub-tensordicts diff --git a/torchrl/objectives/cql.py b/torchrl/objectives/cql.py index 375e3834dfc..6e056589a8c 100644 --- a/torchrl/objectives/cql.py +++ b/torchrl/objectives/cql.py @@ -892,7 +892,7 @@ def alpha_loss(self, tensordict: TensorDictBase) -> Tensor: @property def _alpha(self): - if self.min_log_alpha is not None: + if self.min_log_alpha is not None or self.max_log_alpha is not None: self.log_alpha.data.clamp_(self.min_log_alpha, self.max_log_alpha) alpha = self.log_alpha.data.exp() return alpha diff --git a/torchrl/objectives/crossq.py b/torchrl/objectives/crossq.py index 801180901a7..22e84673641 100644 --- a/torchrl/objectives/crossq.py +++ b/torchrl/objectives/crossq.py @@ -677,7 +677,7 @@ def alpha_loss(self, log_prob: Tensor) -> Tensor: @property def _alpha(self): - if self.min_log_alpha is not None: + if self.min_log_alpha is not None or self.max_log_alpha is not None: self.log_alpha.data.clamp_(self.min_log_alpha, self.max_log_alpha) with torch.no_grad(): alpha = self.log_alpha.exp() diff --git a/torchrl/objectives/decision_transformer.py b/torchrl/objectives/decision_transformer.py index 013e28713bf..a0d193acbfc 100644 --- a/torchrl/objectives/decision_transformer.py +++ b/torchrl/objectives/decision_transformer.py @@ -171,7 +171,7 @@ def _forward_value_estimator_keys(self, **kwargs): @property def alpha(self): - if self.min_log_alpha is not None: + if self.min_log_alpha is not None or self.max_log_alpha is not None: self.log_alpha.data.clamp_(self.min_log_alpha, self.max_log_alpha) with torch.no_grad(): alpha = self.log_alpha.exp() diff --git a/torchrl/objectives/ppo.py b/torchrl/objectives/ppo.py index 9e833d0518b..079a1efa92c 100644 --- a/torchrl/objectives/ppo.py +++ b/torchrl/objectives/ppo.py @@ -69,7 +69,7 @@ class PPOLoss(LossModule): Args: actor_network (ProbabilisticTensorDictSequential): policy operator. - Typically a :class:`~tensordict.nn.ProbabilisticTensorDictSequential` subclass taking observations + Typically, a :class:`~tensordict.nn.ProbabilisticTensorDictSequential` subclass taking observations as input and outputting an action (or actions) as well as its log-probability value. critic_network (ValueOperator): value operator. The critic will usually take the observations as input and return a scalar value (``state_value`` by default) in the output keys. 
@@ -490,7 +490,10 @@ def get_entropy_bonus(self, dist: d.Distribution) -> torch.Tensor: if is_tensor_collection(log_prob): if isinstance(self.tensor_keys.sample_log_prob, NestedKey): - log_prob = log_prob.get(self.tensor_keys.sample_log_prob) + try: + log_prob = log_prob.get(self.tensor_keys.sample_log_prob) + except KeyError as err: + raise _make_lp_get_error(self.tensor_keys, log_prob, err) else: log_prob = log_prob.select(*self.tensor_keys.sample_log_prob) @@ -511,9 +514,12 @@ def _log_weight( ) if self.functional else contextlib.nullcontext(): dist = self.actor_network.get_dist(tensordict) - prev_log_prob = _maybe_get_or_select( - tensordict, self.tensor_keys.sample_log_prob - ) + try: + prev_log_prob = _maybe_get_or_select( + tensordict, self.tensor_keys.sample_log_prob + ) + except KeyError as err: + raise _make_lp_get_error(self.tensor_keys, tensordict, err) if prev_log_prob.requires_grad: raise RuntimeError( @@ -930,7 +936,7 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase: gain = torch.stack([gain1, gain2], -1).min(dim=-1)[0] if is_tensor_collection(gain): - gain = gain.sum(reduce=True) + gain = _sum_td_features(gain) td_out = TensorDict({"loss_objective": -gain}, batch_size=[]) td_out.set("clip_fraction", clip_fraction) @@ -938,7 +944,7 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase: entropy = self.get_entropy_bonus(dist) td_out.set("entropy", entropy.detach().mean()) # for logging td_out.set("kl_approx", kl_approx.detach().mean()) # for logging - td_out.set("loss_entropy", -self.entropy_coef * entropy.mean()) + td_out.set("loss_entropy", -self.entropy_coef * entropy) if self.critic_coef is not None: loss_critic, value_clip_fraction = self.loss_critic(tensordict) td_out.set("loss_critic", loss_critic) @@ -1223,6 +1229,7 @@ def forward(self, tensordict: TensorDictBase) -> TensorDict: current_log_prob = current_dist.log_prob(x, **kwargs) if is_tensor_collection(previous_log_prob): previous_log_prob = _sum_td_features(previous_log_prob) + # Both dists have presumably the same params current_log_prob = _sum_td_features(current_log_prob) kl = (previous_log_prob - current_log_prob).mean(0) kl = kl.unsqueeze(-1) diff --git a/torchrl/objectives/sac.py b/torchrl/objectives/sac.py index dafff17011e..eae6b7feb34 100644 --- a/torchrl/objectives/sac.py +++ b/torchrl/objectives/sac.py @@ -846,7 +846,7 @@ def _alpha_loss(self, log_prob: Tensor) -> Tensor: @property def _alpha(self): - if self.min_log_alpha is not None: + if self.min_log_alpha is not None or self.max_log_alpha is not None: self.log_alpha.data.clamp_(self.min_log_alpha, self.max_log_alpha) with torch.no_grad(): alpha = self.log_alpha.exp() @@ -1374,7 +1374,7 @@ def _alpha_loss(self, log_prob: Tensor) -> Tensor: @property def _alpha(self): - if self.min_log_alpha is not None: + if self.min_log_alpha is not None or self.max_log_alpha is not None: self.log_alpha.data = self.log_alpha.data.clamp( self.min_log_alpha, self.max_log_alpha ) diff --git a/tutorials/sphinx-tutorials/getting-started-0.py b/tutorials/sphinx-tutorials/getting-started-0.py index 45e802ff5b7..8638b775bac 100644 --- a/tutorials/sphinx-tutorials/getting-started-0.py +++ b/tutorials/sphinx-tutorials/getting-started-0.py @@ -106,7 +106,7 @@ print(reset_with_action["action"]) ################################ -# We now need to pass this action tp the environment. +# We now need to pass this action to the environment. 
# We'll be passing the entire tensordict to the ``step`` method, since there # might be more than one tensor to be read in more advanced cases like # Multi-Agent RL or stateless environments: diff --git a/tutorials/sphinx-tutorials/multi_task.py b/tutorials/sphinx-tutorials/multi_task.py index 9ba094e1f20..f9a716f5065 100644 --- a/tutorials/sphinx-tutorials/multi_task.py +++ b/tutorials/sphinx-tutorials/multi_task.py @@ -11,8 +11,6 @@ # sphinx_gallery_start_ignore import warnings -from tensordict import LazyStackedTensorDict - warnings.filterwarnings("ignore") from torch import multiprocessing @@ -32,6 +30,7 @@ # sphinx_gallery_end_ignore +from tensordict import LazyStackedTensorDict from tensordict.nn import TensorDictModule, TensorDictSequential from torch import nn @@ -91,7 +90,7 @@ # ^^^^^^ # # We will design a policy where a backbone reads the "observation" key. -# Then specific sub-components will ready the "observation_stand" and +# Then specific sub-components will read the "observation_stand" and # "observation_walk" keys of the stacked tensordicts, if they are present, # and pass them through the dedicated sub-network. @@ -138,7 +137,7 @@ # Executing diverse tasks in parallel # ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ # -# We can parallelize the operations if the common keys-value pairs share the +# We can parallelize the operations if the common key-value pairs share the # same specs (in particular their shape and dtype must match: you can't do the # following if the observation shapes are different but are pointed to by the # same key).
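The guard change repeated in ``cql.py``, ``crossq.py``, ``decision_transformer.py`` and ``sac.py`` is easiest to see on a plain tensor. The snippet below is a standalone illustration, not TorchRL code: the variable names and the bound of ``2.0`` are invented for the example. It shows why clamping only when ``min_log_alpha`` is set silently ignored a configured ``max_log_alpha``, and relies on the fact that ``Tensor.clamp_`` accepts ``None`` for one of its two bounds, which is what the combined ``or`` condition depends on:

    >>> import torch
    >>> log_alpha = torch.tensor(5.0)
    >>> min_log_alpha, max_log_alpha = None, 2.0  # only an upper bound is configured
    >>> # old guard: nothing happens because min_log_alpha is None
    >>> if min_log_alpha is not None:
    ...     _ = log_alpha.clamp_(min_log_alpha, max_log_alpha)
    >>> log_alpha.exp()  # alpha escapes the intended cap of exp(2.0) ~= 7.39
    tensor(148.4132)
    >>> # new guard: the clamp runs as soon as either bound is configured
    >>> if min_log_alpha is not None or max_log_alpha is not None:
    ...     _ = log_alpha.clamp_(min_log_alpha, max_log_alpha)
    >>> log_alpha.exp()
    tensor(7.3891)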
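The two ``try``/``except KeyError`` blocks added in ``ppo.py`` both apply the same pattern: intercept the bare ``KeyError`` raised when the log-prob entry is missing and re-raise an error that names the key the loss was configured to read. ``_make_lp_get_error`` is a TorchRL-internal helper, so the sketch below uses a hypothetical ``make_missing_lp_error`` stand-in with an invented message, purely to show the shape of the pattern. It assumes, as the ``except KeyError`` in the diff implies, that ``TensorDict.get`` raises ``KeyError`` for a missing key:

    >>> import torch
    >>> from tensordict import TensorDict
    >>> def make_missing_lp_error(expected_key, td, err):
    ...     # hypothetical stand-in for torchrl's internal _make_lp_get_error
    ...     keys = sorted(td.keys(include_nested=True, leaves_only=True), key=str)
    ...     return KeyError(
    ...         f"couldn't find the log-prob entry {expected_key!r} in the input tensordict; "
    ...         f"available keys: {keys}. Check the sample_log_prob key configured on the loss."
    ...     )
    >>> td = TensorDict({"action": torch.zeros(3, 1)}, batch_size=[3])
    >>> try:
    ...     prev_log_prob = td.get("sample_log_prob")  # the step that fails when the key is absent
    ... except KeyError as err:
    ...     # a real implementation would `raise ... from err`; here we only build the message
    ...     friendly = make_missing_lp_error("sample_log_prob", td, err)
    >>> print(friendly.args[0])
    couldn't find the log-prob entry 'sample_log_prob' in the input tensordict; available keys: ['action']. Check the sample_log_prob key configured on the loss.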