Skip to content

Commit

Permalink
init
Browse files Browse the repository at this point in the history
  • Loading branch information
vmoens committed Jan 19, 2024
1 parent 2ecb709 commit 012f9c0
Show file tree
Hide file tree
Showing 27 changed files with 75 additions and 238 deletions.
11 changes: 3 additions & 8 deletions torchrl/collectors/collectors.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,16 +151,11 @@ def _policy_is_tensordict_compatible(policy: nn.Module):
and hasattr(policy, "in_keys")
and hasattr(policy, "out_keys")
):
warnings.warn(
"Passing a policy that is not a TensorDictModuleBase subclass but has in_keys and out_keys "
"will soon be deprecated. We'd like to motivate our users to inherit from this class (which "
raise RuntimeError(
"Passing a policy that is not a tensordict.nn.TensorDictModuleBase subclass but has in_keys and out_keys "
"is deprecated. Users should inherit from this class (which "
"has very few restrictions) to make the experience smoother.",
category=DeprecationWarning,
)
# if the policy is a TensorDictModule or takes a single argument and defines
# in_keys and out_keys then we assume it can already deal with TensorDict input
# to forward and we return True
return True
elif not hasattr(policy, "in_keys") and not hasattr(policy, "out_keys"):
# if it's not a TensorDictModule, and in_keys and out_keys are not defined then
# we assume no TensorDict compatibility and will try to wrap it.
Expand Down
31 changes: 9 additions & 22 deletions torchrl/data/datasets/d4rl.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,7 +146,7 @@ def __init__(
prefetch: int | None = None,
transform: "torchrl.envs.Transform" | None = None, # noqa-F821
split_trajs: bool = False,
from_env: bool = None,
from_env: bool = False,
use_truncated_as_done: bool = True,
direct_download: bool = None,
terminate_on_end: bool = None,
Expand All @@ -165,29 +165,16 @@ def __init__(
direct_download = not self._has_d4rl

if not direct_download:
if from_env is None:
warnings.warn(
"from_env will soon default to ``False``, ie the data will be "
"downloaded without relying on d4rl by default. "
"For now, ``True`` will still be the default. "
"To disable this warning, explicitly pass the ``from_env`` argument "
"during construction of the dataset.",
category=DeprecationWarning,
)
from_env = True
else:
warnings.warn(
"You are using the D4RL library for collecting data. "
"We advise against this use, as D4RL formatting can be "
"inconsistent. "
"To download the D4RL data without the D4RL library, use "
"direct_download=True in the dataset constructor. "
"Recurring to `direct_download=False` will soon be deprecated."
)
warnings.warn(
"You are using the D4RL library for collecting data. "
"We advise against this use, as D4RL formatting can be "
"inconsistent. "
"To download the D4RL data without the D4RL library, use "
"direct_download=True in the dataset constructor. "
"Recurring to `direct_download=False` will soon be deprecated."
)
self.from_env = from_env
else:
if from_env is None:
from_env = False
self.from_env = from_env

if (download == "force") or (download and not self._is_downloaded()):
Expand Down
2 changes: 1 addition & 1 deletion torchrl/data/replay_buffers/storages.py
Original file line number Diff line number Diff line change
Expand Up @@ -1043,7 +1043,7 @@ def _reset_batch_size(x):
shape = x.get("_rb_batch_size", None)
if shape is not None:
warnings.warn(
"Reshaping nested tensordicts will be deprecated soon.",
"Reshaping nested tensordicts will be deprecated in v0.4.0.",
category=DeprecationWarning,
)
data = x.get("_data")
Expand Down
6 changes: 3 additions & 3 deletions torchrl/data/tensor_specs.py
Original file line number Diff line number Diff line change
Expand Up @@ -376,15 +376,15 @@ def high(self, value):
@property
def minimum(self):
warnings.warn(
f"{type(self)}.minimum is going to be deprecated in favour of {type(self)}.low",
f"{type(self)}.minimum is going to be deprecated in favour of {type(self)}.low in v0.4.0",
category=DeprecationWarning,
)
return self._low.to(self.device)

@property
def maximum(self):
warnings.warn(
f"{type(self)}.maximum is going to be deprecated in favour of {type(self)}.high",
f"{type(self)}.maximum is going to be deprecated in favour of {type(self)}.high in v0.4.0",
category=DeprecationWarning,
)
return self._high.to(self.device)
Expand Down Expand Up @@ -1472,7 +1472,7 @@ class BoundedTensorSpec(TensorSpec):
# SPEC_HANDLED_FUNCTIONS = {}
DEPRECATED_KWARGS = (
"The `minimum` and `maximum` keyword arguments are now "
"deprecated in favour of `low` and `high`."
"deprecated in favour of `low` and `high` in v0.4.0."
)
CONFLICTING_KWARGS = (
"The keyword arguments {} and {} conflict. Only one of these can be passed."
Expand Down
2 changes: 1 addition & 1 deletion torchrl/envs/gym_like.py
Original file line number Diff line number Diff line change
Expand Up @@ -521,7 +521,7 @@ def info_dict_reader(self, value: callable):
warnings.warn(
f"Please use {type(self)}.set_info_dict_reader method to set a new info reader. "
f"This method will append a reader to the list of existing readers (if any). "
f"Setting info_dict_reader directly will be soon deprecated.",
f"Setting info_dict_reader directly will be deprecated in v0.4.0.",
category=DeprecationWarning,
)
self._info_dict_reader.append(value)
12 changes: 3 additions & 9 deletions torchrl/envs/transforms/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -2345,15 +2345,9 @@ def __init__(
standard_normal: bool = False,
):
if in_keys is None:
warnings.warn(
"Not passing in_keys to ObservationNorm will soon be deprecated. "
"Ensure you specify the entries to be normalized",
category=DeprecationWarning,
raise RuntimeError(
"Not passing in_keys to ObservationNorm is a deprecated behaviour."
)
in_keys = [
"observation",
"pixels",
]

if out_keys is None:
out_keys = copy(in_keys)
Expand Down Expand Up @@ -2692,7 +2686,7 @@ def __init__(
raise ValueError(f"padding must be one of {self.ACCEPTED_PADDING}")
if padding == "zeros":
warnings.warn(
"Padding option 'zeros' will be deprecated in the future. "
"Padding option 'zeros' will be deprecated in v0.4.0. "
"Please use 'constant' padding with padding_value 0 instead.",
category=DeprecationWarning,
)
Expand Down
9 changes: 3 additions & 6 deletions torchrl/modules/models/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -872,9 +872,6 @@ class DistributionalDQNnet(TensorDictModuleBase):
"""Distributional Deep Q-Network.
Args:
DQNet (nn.Module): (deprecated) Q-Network with output length equal
to the number of atoms:
output.shape = [*batch, atoms, actions].
in_keys (list of str or tuples of str): input keys to the log-softmax
operation. Defaults to ``["action_value"]``.
out_keys (list of str or tuples of str): output keys to the log-softmax
Expand All @@ -888,11 +885,11 @@ class DistributionalDQNnet(TensorDictModuleBase):
"instead."
)

def __init__(self, DQNet: nn.Module = None, in_keys=None, out_keys=None):
def __init__(self, *, in_keys=None, out_keys=None, DQNet: nn.Module = None):
super().__init__()
if DQNet is not None:
warnings.warn(
f"Passing a network to {type(self)} is going to be deprecated.",
f"Passing a network to {type(self)} is going to be deprecated in v0.4.0.",
category=DeprecationWarning,
)
if not (
Expand Down Expand Up @@ -1280,7 +1277,7 @@ def __init__(
device: Optional[DEVICE_TYPING] = None,
) -> None:
warnings.warn(
"LSTMNet is being deprecated in favour of torchrl.modules.LSTMModule, and will be removed soon.",
"LSTMNet is being deprecated in favour of torchrl.modules.LSTMModule, and will be removed in v0.4.0.",
category=DeprecationWarning,
)
super().__init__()
Expand Down
10 changes: 5 additions & 5 deletions torchrl/modules/tensordict_module/actors.py
Original file line number Diff line number Diff line change
Expand Up @@ -445,7 +445,7 @@ def __init__(
):
if isinstance(action_space, TensorSpec):
warnings.warn(
"Using specs in action_space will be deprecated soon,"
"Using specs in action_space will be deprecated in v0.4.0,"
" please use the 'spec' argument if you want to provide an action spec",
category=DeprecationWarning,
)
Expand Down Expand Up @@ -825,7 +825,7 @@ def __init__(
):
if isinstance(action_space, TensorSpec):
warnings.warn(
"Using specs in action_space will be deprecated soon,"
"Using specs in action_space will be deprecated in v0.4.0,"
" please use the 'spec' argument if you want to provide an action spec",
category=DeprecationWarning,
)
Expand Down Expand Up @@ -922,7 +922,7 @@ def __init__(
):
if isinstance(action_space, TensorSpec):
warnings.warn(
"Using specs in action_space will be deprecated soon,"
"Using specs in action_space will be deprecated in v0.4.0,"
" please use the 'spec' argument if you want to provide an action spec",
category=DeprecationWarning,
)
Expand Down Expand Up @@ -1043,7 +1043,7 @@ def __init__(
):
if isinstance(action_space, TensorSpec):
warnings.warn(
"Using specs in action_space will be deprecated soon,"
"Using specs in action_space will be deprecated v0.4.0,"
" please use the 'spec' argument if you want to provide an action spec",
category=DeprecationWarning,
)
Expand Down Expand Up @@ -1189,7 +1189,7 @@ def __init__(
):
if isinstance(action_space, TensorSpec):
warnings.warn(
"Using specs in action_space will be deprecated soon,"
"Using specs in action_space will be deprecated in v0.4.0,"
" please use the 'spec' argument if you want to provide an action spec",
category=DeprecationWarning,
)
Expand Down
99 changes: 2 additions & 97 deletions torchrl/modules/tensordict_module/exploration.py
Original file line number Diff line number Diff line change
Expand Up @@ -247,105 +247,10 @@ def __init__(
action_mask_key: Optional[NestedKey] = None,
spec: Optional[TensorSpec] = None,
):
warnings.warn(
"EGreedyWrapper is deprecated and it will be removed in v0.3. "
"Please use torchrl.modules.EGreedyModule instead.",
category=DeprecationWarning,
raise RuntimeError(
"This class is not removed in favour of torchrl.modules.EGreedyModule."
)

super().__init__(policy)
self.register_buffer("eps_init", torch.tensor([eps_init]))
self.register_buffer("eps_end", torch.tensor([eps_end]))
if self.eps_end > self.eps_init:
raise RuntimeError("eps should decrease over time or be constant")
self.annealing_num_steps = annealing_num_steps
self.register_buffer("eps", torch.tensor([eps_init], dtype=torch.float32))
self.action_key = action_key
self.action_mask_key = action_mask_key
if spec is not None:
if not isinstance(spec, CompositeSpec) and len(self.out_keys) >= 1:
spec = CompositeSpec({action_key: spec}, shape=spec.shape[:-1])
self._spec = spec
elif hasattr(self.td_module, "_spec"):
self._spec = self.td_module._spec.clone()
if action_key not in self._spec.keys():
self._spec[action_key] = None
elif hasattr(self.td_module, "spec"):
self._spec = self.td_module.spec.clone()
if action_key not in self._spec.keys():
self._spec[action_key] = None
else:
self._spec = spec

@property
def spec(self):
return self._spec

def step(self, frames: int = 1) -> None:
"""A step of epsilon decay.
After self.annealing_num_steps, this function is a no-op.
Args:
frames (int): number of frames since last step.
"""
for _ in range(frames):
self.eps.data[0] = max(
self.eps_end.item(),
(
self.eps - (self.eps_init - self.eps_end) / self.annealing_num_steps
).item(),
)

def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
tensordict = self.td_module.forward(tensordict)
if exploration_type() == ExplorationType.RANDOM or exploration_type() is None:
if isinstance(self.action_key, tuple) and len(self.action_key) > 1:
action_tensordict = tensordict.get(self.action_key[:-1])
action_key = self.action_key[-1]
else:
action_tensordict = tensordict
action_key = self.action_key

out = action_tensordict.get(action_key)
eps = self.eps.item()
cond = (
torch.rand(action_tensordict.shape, device=action_tensordict.device)
< eps
).to(out.dtype)
cond = expand_as_right(cond, out)
spec = self.spec
if spec is not None:
if isinstance(spec, CompositeSpec):
spec = spec[self.action_key]
if spec.shape != out.shape:
# In batched envs if the spec is passed unbatched, the rand() will not
# cover all batched dims
if (
not len(spec.shape)
or out.shape[-len(spec.shape) :] == spec.shape
):
spec = spec.expand(out.shape)
else:
raise ValueError(
"Action spec shape does not match the action shape"
)
if self.action_mask_key is not None:
action_mask = tensordict.get(self.action_mask_key, None)
if action_mask is None:
raise KeyError(
f"Action mask key {self.action_mask_key} not found in {tensordict}."
)
spec.update_mask(action_mask)
out = cond * spec.rand().to(out.device) + (1 - cond) * out
else:
raise RuntimeError(
"spec must be provided by the policy or directly to the exploration wrapper."
)
action_tensordict.set(action_key, out)
return tensordict


class AdditiveGaussianWrapper(TensorDictModuleWrapper):
"""Additive Gaussian PO wrapper.
Expand Down
9 changes: 2 additions & 7 deletions torchrl/modules/tensordict_module/rnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import warnings
from typing import Optional, Tuple

import torch
Expand Down Expand Up @@ -555,11 +554,9 @@ def recurrent_mode(self, value):

@property
def temporal_mode(self):
warnings.warn(
raise RuntimeError(
"temporal_mode is deprecated, use recurrent_mode instead.",
category=DeprecationWarning,
)
return self.recurrent_mode

def set_recurrent_mode(self, mode: bool = True):
"""Returns a new copy of the module that shares the same lstm model but with a different ``recurrent_mode`` attribute (if it differs).
Expand Down Expand Up @@ -1255,11 +1252,9 @@ def recurrent_mode(self, value):

@property
def temporal_mode(self):
warnings.warn(
raise RuntimeError(
"temporal_mode is deprecated, use recurrent_mode instead.",
category=DeprecationWarning,
)
return self.recurrent_mode

def set_recurrent_mode(self, mode: bool = True):
"""Returns a new copy of the module that shares the same gru model but with a different ``recurrent_mode`` attribute (if it differs).
Expand Down
6 changes: 2 additions & 4 deletions torchrl/objectives/a2c.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import warnings
from copy import deepcopy
from dataclasses import dataclass
from typing import Tuple
Expand All @@ -21,7 +20,7 @@
from torchrl.objectives.common import LossModule
from torchrl.objectives.utils import (
_cache_values,
_GAMMA_LMBDA_DEPREC_WARNING,
_GAMMA_LMBDA_DEPREC_ERROR,
default_value_kwargs,
distance_loss,
ValueEstimators,
Expand Down Expand Up @@ -261,8 +260,7 @@ def __init__(
self.register_buffer("entropy_coef", torch.tensor(entropy_coef, device=device))
self.register_buffer("critic_coef", torch.tensor(critic_coef, device=device))
if gamma is not None:
warnings.warn(_GAMMA_LMBDA_DEPREC_WARNING, category=DeprecationWarning)
self.gamma = gamma
raise TypeError(_GAMMA_LMBDA_DEPREC_ERROR)
self.loss_critic_type = loss_critic_type

@property
Expand Down
Loading

0 comments on commit 012f9c0

Please sign in to comment.