From 133d709364df09664dba5be1aaac2a3c3dd4c6e2 Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Thu, 19 Dec 2024 10:25:23 +0000 Subject: [PATCH 1/7] [CI] Fix nightly build ghstack-source-id: 5502fa94b6abcc154e020dcb165093fdc30ca025 Pull Request resolved: https://github.com/pytorch/rl/pull/2666 --- .github/workflows/nightly_build.yml | 39 +++++++++++++++-------------- 1 file changed, 20 insertions(+), 19 deletions(-) diff --git a/.github/workflows/nightly_build.yml b/.github/workflows/nightly_build.yml index 08eb61bfa6c..732077f4b58 100644 --- a/.github/workflows/nightly_build.yml +++ b/.github/workflows/nightly_build.yml @@ -21,11 +21,6 @@ on: branches: - "nightly" -env: - ACTIONS_RUNNER_FORCED_INTERNAL_NODE_VERSION: node16 - ACTIONS_RUNNER_FORCE_ACTIONS_NODE_VERSION: node16 - ACTIONS_ALLOW_USE_UNSECURE_NODE_VERSION: true # https://github.com/actions/checkout/issues/1809 - concurrency: # Documentation suggests ${{ github.head_ref }}, but that's only available on pull_request/pull_request_target triggers, so using ${{ github.ref }}. # On master, we want all builds to complete even if merging happens faster to make it easier to discover at which point something broke. 
@@ -41,12 +36,15 @@ jobs: matrix: python_version: [["3.9", "cp39-cp39"], ["3.10", "cp310-cp310"], ["3.11", "cp311-cp311"], ["3.12", "cp312-cp312"]] cuda_support: [["", "cpu", "cpu"]] - container: pytorch/manylinux-${{ matrix.cuda_support[2] }} steps: - name: Checkout torchrl - uses: actions/checkout@v3 + uses: actions/checkout@v4 env: AGENT_TOOLSDIRECTORY: "/opt/hostedtoolcache" + - name: Setup Python + uses: actions/setup-python@v4 + with: + python-version: ${{ matrix.python_version[0] }} - name: Install PyTorch nightly run: | export PATH="/opt/python/${{ matrix.python_version[1] }}/bin:$PATH" @@ -67,7 +65,7 @@ jobs: python3 -mpip install auditwheel auditwheel show dist/* - name: Upload wheel for the test-wheel job - uses: actions/upload-artifact@v3 + uses: actions/upload-artifact@v4 with: name: torchrl-linux-${{ matrix.python_version[0] }}_${{ matrix.cuda_support[2] }}.whl path: dist/*.whl @@ -81,12 +79,15 @@ jobs: matrix: python_version: [["3.9", "cp39-cp39"], ["3.10", "cp310-cp310"], ["3.11", "cp311-cp311"], ["3.12", "cp312-cp312"]] cuda_support: [["", "cpu", "cpu"]] - container: pytorch/manylinux-${{ matrix.cuda_support[2] }} steps: - name: Checkout torchrl - uses: actions/checkout@v3 + uses: actions/checkout@v4 + - name: Setup Python + uses: actions/setup-python@v4 + with: + python-version: ${{ matrix.python_version[0] }} - name: Download built wheels - uses: actions/download-artifact@v3 + uses: actions/download-artifact@v4 with: name: torchrl-linux-${{ matrix.python_version[0] }}_${{ matrix.cuda_support[2] }}.whl path: /tmp/wheels @@ -121,7 +122,7 @@ jobs: env: AGENT_TOOLSDIRECTORY: "/opt/hostedtoolcache" - name: Checkout torchrl - uses: actions/checkout@v3 + uses: actions/checkout@v4 - name: Install PyTorch Nightly run: | export PATH="/opt/python/${{ matrix.python_version[1] }}/bin:$PATH" @@ -138,7 +139,7 @@ jobs: export PATH="/opt/python/${{ matrix.python_version[1] }}/bin:$PATH" python3 -mpip install numpy pytest pillow>=4.1.1 scipy networkx expecttest 
pyyaml - name: Download built wheels - uses: actions/download-artifact@v3 + uses: actions/download-artifact@v4 with: name: torchrl-linux-${{ matrix.python_version[0] }}_${{ matrix.cuda_support[2] }}.whl path: /tmp/wheels @@ -179,7 +180,7 @@ jobs: with: python-version: ${{ matrix.python_version[1] }} - name: Checkout torchrl - uses: actions/checkout@v3 + uses: actions/checkout@v4 - name: Install PyTorch nightly shell: bash run: | @@ -193,7 +194,7 @@ jobs: --package_name torchrl-nightly \ --python-tag=${{ matrix.python-tag }} - name: Upload wheel for the test-wheel job - uses: actions/upload-artifact@v3 + uses: actions/upload-artifact@v4 with: name: torchrl-win-${{ matrix.python_version[0] }}.whl path: dist/*.whl @@ -212,7 +213,7 @@ jobs: with: python-version: ${{ matrix.python_version[1] }} - name: Checkout torchrl - uses: actions/checkout@v3 + uses: actions/checkout@v4 - name: Install PyTorch Nightly shell: bash run: | @@ -229,7 +230,7 @@ jobs: run: | python3 -mpip install git+https://github.com/pytorch/tensordict.git - name: Download built wheels - uses: actions/download-artifact@v3 + uses: actions/download-artifact@v4 with: name: torchrl-win-${{ matrix.python_version[0] }}.whl path: wheels @@ -265,9 +266,9 @@ jobs: python_version: [["3.9", "3.9"], ["3.10", "3.10.3"], ["3.11", "3.11"], ["3.12", "3.12"]] steps: - name: Checkout torchrl - uses: actions/checkout@v3 + uses: actions/checkout@v4 - name: Download built wheels - uses: actions/download-artifact@v3 + uses: actions/download-artifact@v4 with: name: torchrl-win-${{ matrix.python_version[0] }}.whl path: wheels From f4709c143f727b379be7a5334ff77d9ce19d1986 Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Fri, 20 Dec 2024 10:26:42 +0000 Subject: [PATCH 2/7] [BugFix] Compatibility of tensordict primers with batched envs (specifically for LSTM and GRU) ghstack-source-id: e1da58ecfd36ca01b8a11fe90e5f3c5fe77f064c Pull Request resolved: https://github.com/pytorch/rl/pull/2668 --- .../decision_transformer/utils.py | 
2 +- test/test_tensordictmodules.py | 105 +++++++++++++----- test/test_transforms.py | 45 ++++++-- torchrl/envs/batched_envs.py | 36 +++++- torchrl/envs/transforms/transforms.py | 37 ++++-- torchrl/modules/tensordict_module/rnn.py | 14 ++- 6 files changed, 191 insertions(+), 48 deletions(-) diff --git a/sota-implementations/decision_transformer/utils.py b/sota-implementations/decision_transformer/utils.py index d4a67e7d3a9..415e19a1f7c 100644 --- a/sota-implementations/decision_transformer/utils.py +++ b/sota-implementations/decision_transformer/utils.py @@ -109,7 +109,7 @@ def make_transformed_env(base_env, env_cfg, obs_loc, obs_std, train=False): ) # copy action from the input tensordict to the output - transformed_env.append_transform(TensorDictPrimer(action=base_env.action_spec)) + transformed_env.append_transform(TensorDictPrimer(base_env.full_action_spec)) transformed_env.append_transform(DoubleToFloat()) obsnorm = ObservationNorm( diff --git a/test/test_tensordictmodules.py b/test/test_tensordictmodules.py index d3b7b7850f4..c2a34f3797d 100644 --- a/test/test_tensordictmodules.py +++ b/test/test_tensordictmodules.py @@ -4,6 +4,7 @@ # LICENSE file in the root directory of this source tree. 
import argparse +import functools import os import pytest @@ -12,6 +13,7 @@ import torchrl.modules from tensordict import LazyStackedTensorDict, pad, TensorDict, unravel_key_list from tensordict.nn import InteractionType, TensorDictModule, TensorDictSequential +from tensordict.utils import assert_close from torch import nn from torchrl.data.tensor_specs import Bounded, Composite, Unbounded from torchrl.envs import ( @@ -938,10 +940,12 @@ def test_multi_consecutive(self, shape, python_based): @pytest.mark.parametrize("python_based", [True, False]) @pytest.mark.parametrize("parallel", [True, False]) @pytest.mark.parametrize("heterogeneous", [True, False]) - def test_lstm_parallel_env(self, python_based, parallel, heterogeneous): + @pytest.mark.parametrize("within", [False, True]) + def test_lstm_parallel_env(self, python_based, parallel, heterogeneous, within): from torchrl.envs import InitTracker, ParallelEnv, TransformedEnv torch.manual_seed(0) + num_envs = 3 device = "cuda" if torch.cuda.device_count() else "cpu" # tests that hidden states are carried over with parallel envs lstm_module = LSTMModule( @@ -958,25 +962,36 @@ def test_lstm_parallel_env(self, python_based, parallel, heterogeneous): else: cls = SerialEnv - def create_transformed_env(): - primer = lstm_module.make_tensordict_primer() - env = DiscreteActionVecMockEnv( - categorical_action_encoding=True, device=device + if within: + + def create_transformed_env(): + primer = lstm_module.make_tensordict_primer() + env = DiscreteActionVecMockEnv( + categorical_action_encoding=True, device=device + ) + env = TransformedEnv(env) + env.append_transform(InitTracker()) + env.append_transform(primer) + return env + + else: + create_transformed_env = functools.partial( + DiscreteActionVecMockEnv, + categorical_action_encoding=True, + device=device, ) - env = TransformedEnv(env) - env.append_transform(InitTracker()) - env.append_transform(primer) - return env if heterogeneous: create_transformed_env = [ - 
EnvCreator(create_transformed_env), - EnvCreator(create_transformed_env), + EnvCreator(create_transformed_env) for _ in range(num_envs) ] env = cls( create_env_fn=create_transformed_env, - num_workers=2, + num_workers=num_envs, ) + if not within: + env = env.append_transform(InitTracker()) + env.append_transform(lstm_module.make_tensordict_primer()) mlp = TensorDictModule( MLP( @@ -1002,6 +1017,19 @@ def create_transformed_env(): data = env.rollout(10, actor, break_when_any_done=break_when_any_done) assert (data.get(("next", "recurrent_state_c")) != 0.0).all() assert (data.get("recurrent_state_c") != 0.0).any() + return data + + @pytest.mark.parametrize("python_based", [True, False]) + @pytest.mark.parametrize("parallel", [True, False]) + @pytest.mark.parametrize("heterogeneous", [True, False]) + def test_lstm_parallel_within(self, python_based, parallel, heterogeneous): + out_within = self.test_lstm_parallel_env( + python_based, parallel, heterogeneous, within=True + ) + out_not_within = self.test_lstm_parallel_env( + python_based, parallel, heterogeneous, within=False + ) + assert_close(out_within, out_not_within) @pytest.mark.skipif( not _has_functorch, reason="vmap can only be used with functorch" @@ -1330,10 +1358,12 @@ def test_multi_consecutive(self, shape, python_based): @pytest.mark.parametrize("python_based", [True, False]) @pytest.mark.parametrize("parallel", [True, False]) @pytest.mark.parametrize("heterogeneous", [True, False]) - def test_gru_parallel_env(self, python_based, parallel, heterogeneous): + @pytest.mark.parametrize("within", [False, True]) + def test_gru_parallel_env(self, python_based, parallel, heterogeneous, within): from torchrl.envs import InitTracker, ParallelEnv, TransformedEnv torch.manual_seed(0) + num_workers = 3 device = "cuda" if torch.cuda.device_count() else "cpu" # tests that hidden states are carried over with parallel envs @@ -1347,15 +1377,24 @@ def test_gru_parallel_env(self, python_based, parallel, heterogeneous): 
python_based=python_based, ) - def create_transformed_env(): - primer = gru_module.make_tensordict_primer() - env = DiscreteActionVecMockEnv( - categorical_action_encoding=True, device=device + if within: + + def create_transformed_env(): + primer = gru_module.make_tensordict_primer() + env = DiscreteActionVecMockEnv( + categorical_action_encoding=True, device=device + ) + env = TransformedEnv(env) + env.append_transform(InitTracker()) + env.append_transform(primer) + return env + + else: + create_transformed_env = functools.partial( + DiscreteActionVecMockEnv, + categorical_action_encoding=True, + device=device, ) - env = TransformedEnv(env) - env.append_transform(InitTracker()) - env.append_transform(primer) - return env if parallel: cls = ParallelEnv @@ -1363,14 +1402,17 @@ def create_transformed_env(): cls = SerialEnv if heterogeneous: create_transformed_env = [ - EnvCreator(create_transformed_env), - EnvCreator(create_transformed_env), + EnvCreator(create_transformed_env) for _ in range(num_workers) ] - env = cls( + env: ParallelEnv | SerialEnv = cls( create_env_fn=create_transformed_env, - num_workers=2, + num_workers=num_workers, ) + if not within: + primer = gru_module.make_tensordict_primer() + env = env.append_transform(InitTracker()) + env.append_transform(primer) mlp = TensorDictModule( MLP( @@ -1396,6 +1438,19 @@ def create_transformed_env(): data = env.rollout(10, actor, break_when_any_done=break_when_any_done) assert (data.get("recurrent_state") != 0.0).any() assert (data.get(("next", "recurrent_state")) != 0.0).all() + return data + + @pytest.mark.parametrize("python_based", [True, False]) + @pytest.mark.parametrize("parallel", [True, False]) + @pytest.mark.parametrize("heterogeneous", [True, False]) + def test_gru_parallel_within(self, python_based, parallel, heterogeneous): + out_within = self.test_gru_parallel_env( + python_based, parallel, heterogeneous, within=True + ) + out_not_within = self.test_gru_parallel_env( + python_based, parallel, 
heterogeneous, within=False + ) + assert_close(out_within, out_not_within) @pytest.mark.skipif( not _has_functorch, reason="vmap can only be used with functorch" diff --git a/test/test_transforms.py b/test/test_transforms.py index cc3ca40b059..44ebce72c5c 100644 --- a/test/test_transforms.py +++ b/test/test_transforms.py @@ -7408,7 +7408,7 @@ def make_env(): def test_trans_parallel_env_check(self, maybe_fork_ParallelEnv): env = TransformedEnv( maybe_fork_ParallelEnv(2, ContinuousActionVecMockEnv), - TensorDictPrimer(mykey=Unbounded([2, 4])), + TensorDictPrimer(mykey=Unbounded([4])), ) try: check_env_specs(env) @@ -7423,11 +7423,39 @@ def test_trans_parallel_env_check(self, maybe_fork_ParallelEnv): pass @pytest.mark.parametrize("spec_shape", [[4], [2, 4]]) - def test_trans_serial_env_check(self, spec_shape): - env = TransformedEnv( - SerialEnv(2, ContinuousActionVecMockEnv), - TensorDictPrimer(mykey=Unbounded(spec_shape)), - ) + @pytest.mark.parametrize("expand_specs", [True, False, None]) + def test_trans_serial_env_check(self, spec_shape, expand_specs): + if expand_specs is None: + with pytest.warns(FutureWarning, match=""): + env = TransformedEnv( + SerialEnv(2, ContinuousActionVecMockEnv), + TensorDictPrimer( + mykey=Unbounded(spec_shape), expand_specs=expand_specs + ), + ) + env.observation_spec + elif expand_specs is True: + shape = spec_shape[:-1] + env = TransformedEnv( + SerialEnv(2, ContinuousActionVecMockEnv), + TensorDictPrimer( + Composite(mykey=Unbounded(spec_shape), shape=shape), + expand_specs=expand_specs, + ), + ) + else: + # If we don't expand, we can't use [4] + env = TransformedEnv( + SerialEnv(2, ContinuousActionVecMockEnv), + TensorDictPrimer( + mykey=Unbounded(spec_shape), expand_specs=expand_specs + ), + ) + if spec_shape == [4]: + with pytest.raises(ValueError): + env.observation_spec + return + check_env_specs(env) assert "mykey" in env.reset().keys() r = env.rollout(3) @@ -10310,9 +10338,8 @@ def _make_transform_env(self, out_key, 
base_env): transform = KLRewardTransform(actor, out_keys=out_key) return Compose( TensorDictPrimer( - primers={ - "sample_log_prob": Unbounded(shape=base_env.action_spec.shape[:-1]) - } + sample_log_prob=Unbounded(shape=base_env.action_spec.shape[:-1]), + shape=base_env.shape, ), transform, ) diff --git a/torchrl/envs/batched_envs.py b/torchrl/envs/batched_envs.py index 17bd28c8390..f7a25c1bd5c 100644 --- a/torchrl/envs/batched_envs.py +++ b/torchrl/envs/batched_envs.py @@ -1744,14 +1744,39 @@ def _step(self, tensordict: TensorDictBase) -> TensorDictBase: # We keep track of which keys are present to let the worker know what # should be passed to the env (we don't want to pass done states for instance) next_td_keys = list(next_td_passthrough.keys(True, True)) + next_shared_tensordict_parent = shared_tensordict_parent.get("next") + + # We separate keys that are and are not present in the buffer here and not in step_and_maybe_reset. + # The reason we do that is that the policy may write stuff in 'next' that is not part of the specs of + # the batched env but part of the specs of a transformed batched env. + # If that is the case, `update_` will fail to find the entries to update. + # What we do instead is keeping the tensors on the side and putting them back after completing _step. 
+ keys_to_update, keys_to_copy = zip( + *[ + (key, None) + if key in next_shared_tensordict_parent.keys(True, True) + else (None, key) + for key in next_td_keys + ] + ) + keys_to_update = [key for key in keys_to_update if key is not None] + keys_to_copy = [key for key in keys_to_copy if key is not None] data = [ - {"next_td_passthrough_keys": next_td_keys} + {"next_td_passthrough_keys": keys_to_update} for _ in range(self.num_workers) ] - shared_tensordict_parent.get("next").update_( - next_td_passthrough, non_blocking=self.non_blocking - ) + if keys_to_update: + next_shared_tensordict_parent.update_( + next_td_passthrough, + non_blocking=self.non_blocking, + keys_to_update=keys_to_update, + ) + if keys_to_copy: + next_td_passthrough = next_td_passthrough.select(*keys_to_copy) + else: + next_td_passthrough = None else: + next_td_passthrough = None data = [{} for _ in range(self.num_workers)] if self._non_tensor_keys: @@ -1807,6 +1832,9 @@ def select_and_clone(name, tensor): LazyStackedTensorDict(*non_tensor_tds), keys_to_update=self._non_tensor_keys, ) + if next_td_passthrough is not None: + out.update(next_td_passthrough) + self._sync_w2m() if partial_steps is not None: result = out.new_zeros(tensordict_save.shape) diff --git a/torchrl/envs/transforms/transforms.py b/torchrl/envs/transforms/transforms.py index f3329d085df..64fad524d94 100644 --- a/torchrl/envs/transforms/transforms.py +++ b/torchrl/envs/transforms/transforms.py @@ -4984,6 +4984,7 @@ def __init__( | Dict[NestedKey, float] | Dict[NestedKey, Callable] = None, reset_key: NestedKey | None = None, + expand_specs: bool = None, **kwargs, ): self.device = kwargs.pop("device", None) @@ -4995,8 +4996,16 @@ def __init__( ) kwargs = primers if not isinstance(kwargs, Composite): - kwargs = Composite(kwargs) - self.primers = kwargs + shape = kwargs.pop("shape", None) + device = kwargs.pop("device", None) + if "batch_size" in kwargs.keys(): + extra_kwargs = {"batch_size": kwargs.pop("batch_size")} + else: + 
extra_kwargs = {} + primers = Composite(kwargs, device=device, shape=shape, **extra_kwargs) + self.primers = primers + self.expand_specs = expand_specs + if random and default_value: raise ValueError( "Setting random to True and providing a default_value are incompatible." @@ -5089,12 +5098,26 @@ def transform_observation_spec(self, observation_spec: Composite) -> Composite: ) if self.primers.shape != observation_spec.shape: - try: - # We try to set the primer shape to the observation spec shape - self.primers.shape = observation_spec.shape - except ValueError: - # If we fail, we expand them to that shape + if self.expand_specs: self.primers = self._expand_shape(self.primers) + elif self.expand_specs is None: + warnings.warn( + f"expand_specs wasn't specified in the {type(self).__name__} constructor. " + f"The current behaviour is that the transform will attempt to set the shape of the composite " + f"spec, and if this can't be done it will be expanded. " + f"From v0.8, a mismatched shape between the spec of the transform and the env's batch_size " + f"will raise an exception.", + category=FutureWarning, + ) + try: + # We try to set the primer shape to the observation spec shape + self.primers.shape = observation_spec.shape + except ValueError: + # If we fail, we expand them to that shape + self.primers = self._expand_shape(self.primers) + else: + self.primers.shape = observation_spec.shape + device = observation_spec.device observation_spec.update(self.primers.clone().to(device)) return observation_spec diff --git a/torchrl/modules/tensordict_module/rnn.py b/torchrl/modules/tensordict_module/rnn.py index f4ceb648665..68309c346cd 100644 --- a/torchrl/modules/tensordict_module/rnn.py +++ b/torchrl/modules/tensordict_module/rnn.py @@ -592,6 +592,10 @@ def make_tensordict_primer(self): inputs and outputs (recurrent states) during rollout execution. That way, the data can be shared across processes and dealt with properly. 
+ When using batched environments such as :class:`~torchrl.envs.ParallelEnv`, the transform can be used at the + single env instance level (i.e., a batch of transformed envs with tensordict primers set within) or at the + batched env instance level (i.e., a transformed batch of regular envs). + Not including a ``TensorDictPrimer`` in the environment may result in poorly defined behaviors, for instance in parallel settings where a step involves copying the new recurrent state from ``"next"`` to the root tensordict, which the meth:`~torchrl.EnvBase.step_mdp` method will not be able to do as the recurrent states @@ -649,7 +653,8 @@ def make_tuple(key): { in_key1: Unbounded(shape=(self.lstm.num_layers, self.lstm.hidden_size)), in_key2: Unbounded(shape=(self.lstm.num_layers, self.lstm.hidden_size)), - } + }, + expand_specs=True, ) @property @@ -1410,6 +1415,10 @@ def make_tensordict_primer(self): tensordict, which the meth:`~torchrl.EnvBase.step_mdp` method will not be able to do as the recurrent states are not registered within the environment specs. + When using batched environments such as :class:`~torchrl.envs.ParallelEnv`, the transform can be used at the + single env instance level (i.e., a batch of transformed envs with tensordict primers set within) or at the + batched env instance level (i.e., a transformed batch of regular envs). + See :func:`torchrl.modules.utils.get_primers_from_module` for a method to generate all primers for a given module. 
@@ -1459,7 +1468,8 @@ def make_tuple(key): return TensorDictPrimer( { in_key1: Unbounded(shape=(self.gru.num_layers, self.gru.hidden_size)), - } + }, + expand_specs=True, ) @property From 21eeca42ca715e6b5b80560713e0f280cb825002 Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Fri, 20 Dec 2024 10:26:42 +0000 Subject: [PATCH 3/7] [BugFix] Avoid KeyError in slice sampler (for compile) ghstack-source-id: 6e2a3036f0e50d365387cced50a761b97a47317d Pull Request resolved: https://github.com/pytorch/rl/pull/2670 --- torchrl/data/replay_buffers/samplers.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/torchrl/data/replay_buffers/samplers.py b/torchrl/data/replay_buffers/samplers.py index bbdf2387683..2ad0550ed06 100644 --- a/torchrl/data/replay_buffers/samplers.py +++ b/torchrl/data/replay_buffers/samplers.py @@ -1485,13 +1485,13 @@ def _get_index( truncated[seq_length.cumsum(0) - 1] = 1 index = index.to(torch.long).unbind(-1) st_index = storage[index] - try: - done = st_index[done_key] | truncated - except KeyError: + done = st_index.get(done_key, default=None) + if done is None: done = truncated.clone() - try: - terminated = st_index[terminated_key] - except KeyError: + else: + done = done | truncated + terminated = st_index.get(terminated_key, default=None) + if terminated is None: terminated = torch.zeros_like(truncated) return index, { truncated_key: truncated, From 4fd54fef493da9bd2084a574ff731f844c83913c Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Fri, 20 Dec 2024 10:26:43 +0000 Subject: [PATCH 4/7] [Performance] Avoid cloning trajs in SliceSampler ghstack-source-id: 2e133fcea716b202694cfa84df3f6e4ba3507bbc Pull Request resolved: https://github.com/pytorch/rl/pull/2671 --- torchrl/data/replay_buffers/samplers.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchrl/data/replay_buffers/samplers.py b/torchrl/data/replay_buffers/samplers.py index 2ad0550ed06..273cf627521 100644 --- 
a/torchrl/data/replay_buffers/samplers.py +++ b/torchrl/data/replay_buffers/samplers.py @@ -1243,7 +1243,7 @@ def _get_stop_and_length(self, storage, fallback=True): "Could not get a tensordict out of the storage, which is required for SliceSampler to compute the trajectories." ) vals = self._find_start_stop_traj( - trajectory=trajectory.clone(), + trajectory=trajectory, at_capacity=storage._is_full, cursor=getattr(storage, "_last_cursor", None), ) From 84c3ec3221a9f9b6080cfe12e333bc38d3d099d4 Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Fri, 20 Dec 2024 10:26:44 +0000 Subject: [PATCH 5/7] [Performance] Accelerate slice sampler on GPU ghstack-source-id: a4dc1515d8b51f5ec150b2fae4e1a84254f2af09 Pull Request resolved: https://github.com/pytorch/rl/pull/2672 --- torchrl/data/replay_buffers/samplers.py | 48 ++++++++++++++++++++----- torchrl/data/replay_buffers/utils.py | 8 +++++ 2 files changed, 48 insertions(+), 8 deletions(-) diff --git a/torchrl/data/replay_buffers/samplers.py b/torchrl/data/replay_buffers/samplers.py index 273cf627521..fc27401d5e5 100644 --- a/torchrl/data/replay_buffers/samplers.py +++ b/torchrl/data/replay_buffers/samplers.py @@ -24,7 +24,7 @@ from torchrl._utils import _replace_last, logger from torchrl.data.replay_buffers.storages import Storage, StorageEnsemble, TensorStorage -from torchrl.data.replay_buffers.utils import _is_int, unravel_index +from torchrl.data.replay_buffers.utils import _auto_device, _is_int, unravel_index try: from torchrl._torchrl import ( @@ -726,6 +726,10 @@ class SliceSampler(Sampler): This class samples sub-trajectories with replacement. For a version without replacement, see :class:`~torchrl.data.replay_buffers.samplers.SliceSamplerWithoutReplacement`. + .. note:: `SliceSampler` can be slow to retrieve the trajectory indices. To accelerate + its execution, prefer using `end_key` over `traj_key`, and consider the following + keyword arguments: :attr:`compile`, :attr:`cache_values` and :attr:`use_gpu`. 
+ Keyword Args: num_slices (int): the number of slices to be sampled. The batch-size must be greater or equal to the ``num_slices`` argument. Exclusive @@ -796,6 +800,10 @@ class SliceSampler(Sampler): that at least `slice_len - i` samples will be gathered for each sampled trajectory. Using tuples allows a fine grained control over the span on the left (beginning of the stored trajectory) and on the right (end of the stored trajectory). + use_gpu (bool or torch.device): if ``True`` (or if a device is passed), an accelerator + will be used to retrieve the indices of the trajectory starts. This can significantly + accelerate the sampling when the buffer content is large. + Defaults to ``False``. .. note:: To recover the trajectory splits in the storage, :class:`~torchrl.data.replay_buffers.samplers.SliceSampler` will first @@ -985,6 +993,7 @@ def __init__( strict_length: bool = True, compile: bool | dict = False, span: bool | int | Tuple[bool | int, bool | int] = False, + use_gpu: torch.device | bool = False, ): self.num_slices = num_slices self.slice_len = slice_len @@ -995,6 +1004,14 @@ def __init__( self._fetch_traj = True self.strict_length = strict_length self._cache = {} + self.use_gpu = bool(use_gpu) + self._gpu_device = ( + None + if not self.use_gpu + else torch.device(use_gpu) + if not isinstance(use_gpu, bool) + else _auto_device() + ) if isinstance(span, (bool, int)): span = (span, span) @@ -1086,9 +1103,8 @@ def __repr__(self): f"strict_length={self.strict_length})" ) - @classmethod def _find_start_stop_traj( - cls, *, trajectory=None, end=None, at_capacity: bool, cursor=None + self, *, trajectory=None, end=None, at_capacity: bool, cursor=None ): if trajectory is not None: # slower @@ -1141,10 +1157,15 @@ def _find_start_stop_traj( raise RuntimeError( "Expected the end-of-trajectory signal to be at least 1-dimensional."
) - return cls._end_to_start_stop(length=length, end=end) - - @staticmethod - def _end_to_start_stop(end, length): + return self._end_to_start_stop(length=length, end=end) + + def _end_to_start_stop(self, end, length): + device = None + if self.use_gpu: + gpu_device = self._gpu_device + if end.device != gpu_device: + device = end.device + end = end.to(self._gpu_device) # Using transpose ensures the start and stop are sorted the same way stop_idx = end.transpose(0, -1).nonzero() stop_idx[:, [0, -1]] = stop_idx[:, [-1, 0]].clone() @@ -1171,6 +1192,8 @@ def _end_to_start_stop(end, length): pass lengths = stop_idx[:, 0] - start_idx[:, 0] + 1 lengths[lengths <= 0] = lengths[lengths <= 0] + length + if device is not None: + return start_idx.to(device), stop_idx.to(device), lengths.to(device) return start_idx, stop_idx, lengths def _start_to_end(self, st: torch.Tensor, length: int): @@ -1547,6 +1570,10 @@ class SliceSamplerWithoutReplacement(SliceSampler, SamplerWithoutReplacement): the sampler, and continuous sampling without replacement is currently not allowed. + .. note:: `SliceSamplerWithoutReplacement` can be slow to retrieve the trajectory indices. To accelerate + its execution, prefer using `end_key` over `traj_key`, and consider the following + keyword arguments: :attr:`compile`, :attr:`cache_values` and :attr:`use_gpu`. + Keyword Args: drop_last (bool, optional): if ``True``, the last incomplete sample (if any) will be dropped. If ``False``, this last sample will be kept. @@ -1589,6 +1616,10 @@ class SliceSamplerWithoutReplacement(SliceSampler, SamplerWithoutReplacement): the :meth:`~sample` method will be compiled with :func:`~torch.compile`. Keyword arguments can also be passed to torch.compile with this arg. Defaults to ``False``. + use_gpu (bool or torch.device): if ``True`` (or if a device is passed), an accelerator + will be used to retrieve the indices of the trajectory starts.
This can significantly + accelerate the sampling when the buffer content is large. + Defaults to ``False``. .. note:: To recover the trajectory splits in the storage, :class:`~torchrl.data.replay_buffers.samplers.SliceSamplerWithoutReplacement` will first @@ -1693,7 +1724,6 @@ class SliceSamplerWithoutReplacement(SliceSampler, SamplerWithoutReplacement): tensor([[0., 0., 0., 0., 0.], [1., 1., 1., 1., 1.]]) - """ def __init__( @@ -1710,6 +1740,7 @@ def __init__( strict_length: bool = True, shuffle: bool = True, compile: bool | dict = False, + use_gpu: bool | torch.device = False, ): SliceSampler.__init__( self, @@ -1723,6 +1754,7 @@ def __init__( ends=ends, trajectories=trajectories, compile=compile, + use_gpu=use_gpu, ) SamplerWithoutReplacement.__init__(self, drop_last=drop_last, shuffle=shuffle) diff --git a/torchrl/data/replay_buffers/utils.py b/torchrl/data/replay_buffers/utils.py index ef941a6ca90..1e8985537f3 100644 --- a/torchrl/data/replay_buffers/utils.py +++ b/torchrl/data/replay_buffers/utils.py @@ -1034,3 +1034,11 @@ def tree_iter(pytree): # noqa: F811 def tree_iter(pytree): # noqa: F811 """A version-compatible wrapper around tree_iter.""" yield from torch.utils._pytree.tree_iter(pytree) + + +def _auto_device() -> torch.device: + if torch.cuda.is_available(): + return torch.device("cuda:0") + elif torch.mps.is_available(): + return torch.device("mps:0") + return torch.device("cpu") From ab4250ec712094d978d2071b195ed7f6dab00dd8 Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Fri, 20 Dec 2024 10:26:46 +0000 Subject: [PATCH 6/7] [BugFix] Fix batching envs with non tensor data ghstack-source-id: daba8a95459cfa978da09291757b6380fab4f308 Pull Request resolved: https://github.com/pytorch/rl/pull/2674 --- torchrl/envs/batched_envs.py | 19 ++++++++++--------- 1 file changed, 10 insertions(+), 9 deletions(-) diff --git a/torchrl/envs/batched_envs.py b/torchrl/envs/batched_envs.py index f7a25c1bd5c..5b6763f6910 100644 --- a/torchrl/envs/batched_envs.py +++
b/torchrl/envs/batched_envs.py @@ -730,19 +730,20 @@ def _create_td(self) -> None: ) ) env_output_keys = env_output_keys.union(self.reward_keys + self.done_keys) - env_obs_keys = [ - key for key in env_obs_keys if key not in self._non_tensor_keys - ] - env_input_keys = [ - key for key in env_input_keys if key not in self._non_tensor_keys - ] - env_output_keys = [ - key for key in env_output_keys if key not in self._non_tensor_keys - ] self._env_obs_keys = sorted(env_obs_keys, key=_sort_keys) self._env_input_keys = sorted(env_input_keys, key=_sort_keys) self._env_output_keys = sorted(env_output_keys, key=_sort_keys) + self._env_obs_keys = [ + key for key in self._env_obs_keys if key not in self._non_tensor_keys + ] + self._env_input_keys = [ + key for key in self._env_input_keys if key not in self._non_tensor_keys + ] + self._env_output_keys = [ + key for key in self._env_output_keys if key not in self._non_tensor_keys + ] + reset_keys = self.reset_keys self._selected_keys = ( set(self._env_output_keys) From d009835b4fccd1482e9a2bd2b597e00824b95c2d Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Fri, 20 Dec 2024 12:12:30 +0000 Subject: [PATCH 7/7] [Example] RNN-based policy example ghstack-source-id: ef0087e9b5cba40be428f57ef70ecd2f63483d03 Pull Request resolved: https://github.com/pytorch/rl/pull/2675 --- examples/agents/recurrent_actor.py | 205 +++++++++++++++++++++++ torchrl/modules/tensordict_module/rnn.py | 4 +- 2 files changed, 207 insertions(+), 2 deletions(-) create mode 100644 examples/agents/recurrent_actor.py diff --git a/examples/agents/recurrent_actor.py b/examples/agents/recurrent_actor.py new file mode 100644 index 00000000000..16ec64be626 --- /dev/null +++ b/examples/agents/recurrent_actor.py @@ -0,0 +1,205 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. 
+ + +""" +This code exemplifies how an actor that uses a RNN backbone can be built. + +It is based on snippets from the DQN with RNN tutorial. + +There are two main APIs to be aware of when using RNNs, and dedicated notes regarding these can be found at the end +of this example: the `set_recurrent_mode` context manager, and the `make_tensordict_primer` method. + +""" +from collections import OrderedDict + +import torch +from tensordict.nn import TensorDictModule as Mod, TensorDictSequential as Seq +from torch import nn + +from torchrl.envs import ( + Compose, + GrayScale, + GymEnv, + InitTracker, + ObservationNorm, + Resize, + RewardScaling, + StepCounter, + ToTensorImage, + TransformedEnv, +) +from torchrl.modules import ConvNet, LSTMModule, MLP, QValueModule, set_recurrent_mode + +# Define the device to use for computations (GPU if available, otherwise CPU) +device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") + +# Create a transformed environment using the CartPole-v1 gym environment +env = TransformedEnv( + GymEnv("CartPole-v1", from_pixels=True, device=device), + # Apply a series of transformations to the environment: + # 1. Convert observations to tensor images + # 2. Convert images to grayscale + # 3. Resize images to 84x84 pixels + # 4. Keep track of the step count + # 5. Initialize a tracker for the environment + # 6. Scale rewards by a factor of 0.1 + # 7. 
Normalize observations to have zero mean and unit variance (we'll adapt that dynamically later) + Compose( + ToTensorImage(), + GrayScale(), + Resize(84, 84), + StepCounter(), + InitTracker(), + RewardScaling(loc=0.0, scale=0.1), + ObservationNorm(standard_normal=True, in_keys=["pixels"]), + ), +) + +# Initialize the normalization statistics for the observation norm transform +env.transform[-1].init_stats(1000, reduce_dim=[0, 1, 2], cat_dim=0, keep_dims=[0]) + +# Reset the environment to get an initial observation +td = env.reset() + +# Define a feature extractor module that takes pixel observations as input +# and outputs an embedding vector +feature = Mod( + ConvNet( + num_cells=[32, 32, 64], + squeeze_output=True, + aggregator_class=nn.AdaptiveAvgPool2d, + aggregator_kwargs={"output_size": (1, 1)}, + device=device, + ), + in_keys=["pixels"], + out_keys=["embed"], +) + +# Get the shape of the embedding vector output by the feature extractor +with torch.no_grad(): + n_cells = feature(env.reset())["embed"].shape[-1] + +# Define an LSTM module that takes the embedding vector as input and outputs +# a new embedding vector +lstm = LSTMModule( + input_size=n_cells, + hidden_size=128, + device=device, + in_key="embed", + out_key="embed", +) + +# Define a multi-layer perceptron (MLP) module that takes the LSTM output as +# input and outputs action values +mlp = MLP( + out_features=2, + num_cells=[ + 64, + ], + device=device, +) + +# Initialize the bias of the last layer of the MLP to zero +mlp[-1].bias.data.fill_(0.0) + +# Wrap the MLP in a TensorDictModule to handle input/output keys +mlp = Mod(mlp, in_keys=["embed"], out_keys=["action_value"]) + +# Define a Q-value module that computes the Q-value of the current state +qval = QValueModule(action_space=None, spec=env.action_spec) + +# Add a TensorDictPrimer to the environment to ensure that the policy is aware +# of the supplementary inputs and outputs (recurrent states) during rollout execution +# This is necessary 
when using batched environments or parallel data collection +env.append_transform(lstm.make_tensordict_primer()) + +# Create a sequential module that combines the feature extractor, LSTM, MLP, and Q-value modules +policy = Seq(OrderedDict(feature=feature, lstm=lstm, mlp=mlp, qval=qval)) + +# Roll out the policy in the environment for 100 steps +rollout = env.rollout(100, policy) +print(rollout) + +# Print result: +# +# TensorDict( +# fields={ +# action: Tensor(shape=torch.Size([10, 2]), device=cpu, dtype=torch.int64, is_shared=False), +# action_value: Tensor(shape=torch.Size([10, 2]), device=cpu, dtype=torch.float32, is_shared=False), +# chosen_action_value: Tensor(shape=torch.Size([10, 1]), device=cpu, dtype=torch.float32, is_shared=False), +# done: Tensor(shape=torch.Size([10, 1]), device=cpu, dtype=torch.bool, is_shared=False), +# embed: Tensor(shape=torch.Size([10, 128]), device=cpu, dtype=torch.float32, is_shared=False), +# is_init: Tensor(shape=torch.Size([10, 1]), device=cpu, dtype=torch.bool, is_shared=False), +# next: TensorDict( +# fields={ +# done: Tensor(shape=torch.Size([10, 1]), device=cpu, dtype=torch.bool, is_shared=False), +# is_init: Tensor(shape=torch.Size([10, 1]), device=cpu, dtype=torch.bool, is_shared=False), +# pixels: Tensor(shape=torch.Size([10, 1, 84, 84]), device=cpu, dtype=torch.float32, is_shared=False), +# recurrent_state_c: Tensor(shape=torch.Size([10, 1, 128]), device=cpu, dtype=torch.float32, is_shared=False), +# recurrent_state_h: Tensor(shape=torch.Size([10, 1, 128]), device=cpu, dtype=torch.float32, is_shared=False), +# reward: Tensor(shape=torch.Size([10, 1]), device=cpu, dtype=torch.float32, is_shared=False), +# step_count: Tensor(shape=torch.Size([10, 1]), device=cpu, dtype=torch.int64, is_shared=False), +# terminated: Tensor(shape=torch.Size([10, 1]), device=cpu, dtype=torch.bool, is_shared=False), +# truncated: Tensor(shape=torch.Size([10, 1]), device=cpu, dtype=torch.bool, is_shared=False)}, +# batch_size=torch.Size([10]), 
+# device=cpu, +# is_shared=False), +# pixels: Tensor(shape=torch.Size([10, 1, 84, 84]), device=cpu, dtype=torch.float32, is_shared=False), +# recurrent_state_c: Tensor(shape=torch.Size([10, 1, 128]), device=cpu, dtype=torch.float32, is_shared=False), +# recurrent_state_h: Tensor(shape=torch.Size([10, 1, 128]), device=cpu, dtype=torch.float32, is_shared=False), +# step_count: Tensor(shape=torch.Size([10, 1]), device=cpu, dtype=torch.int64, is_shared=False), +# terminated: Tensor(shape=torch.Size([10, 1]), device=cpu, dtype=torch.bool, is_shared=False), +# truncated: Tensor(shape=torch.Size([10, 1]), device=cpu, dtype=torch.bool, is_shared=False)}, +# batch_size=torch.Size([10]), +# device=cpu, +# is_shared=False) +# + +# Notes: +# 1. make_tensordict_primer +# +# Regarding make_tensordict_primer, it creates a TensorDictPrimer object that ensures the policy is aware +# of the supplementary inputs and outputs (recurrent states) during rollout execution. +# This is necessary when using batched environments or parallel data collection, as the recurrent states +# need to be shared across processes and dealt with properly. +# +# In other words, make_tensordict_primer adds the LSTM's hidden states to the environment's specs, +# allowing the environment to properly handle the recurrent states during rollouts. Without it, the policy +# would not be able to use the LSTM's memory buffers correctly, leading to poorly defined behaviors, +# especially in parallel settings. +# +# By adding the TensorDictPrimer to the environment, you ensure that the policy can correctly use the +# LSTM's recurrent states, even when running in parallel or batched environments. This is why +# env.append_transform(lstm.make_tensordict_primer()) is called before creating the policy and rolling it +# out in the environment. +# +# 2. Using the LSTM to process multiple steps at once. 
+# +# When set_recurrent_mode("recurrent") is used, the LSTM will process the entire input tensordict as a sequence, using +# its recurrent connections to maintain state across time steps. This mode may utilize CuDNN to accelerate the processing +# of the sequence on CUDA devices. The behavior in this mode is akin to torch.nn.LSTM, where the LSTM expects the input +# data to be organized in batches of sequences. +# +# On the other hand, when set_recurrent_mode("sequential") is used, the +# LSTM will process each step in the input tensordict independently, without maintaining any state across time steps. This +# mode makes the LSTM behave similarly to torch.nn.LSTMCell, where each input is treated as a separate, independent +# element. +# +# In the example code, set_recurrent_mode("recurrent") is used to process a tensordict of shape [T], where T +# is the number of steps. This allows the LSTM to use its recurrent connections to maintain state across the entire +# sequence. +# +# In contrast, set_recurrent_mode("sequential") is used to process a single step from the tensordict (i.e., +# rollout[0]). In this case, the LSTM does not use its recurrent connections, and simply processes the single step as if +# it were an independent input. + +with set_recurrent_mode("recurrent"): + # Process a tensordict of shape [T] where T is a number of steps + print(policy(rollout)) + +with set_recurrent_mode("sequential"): + # Process a single step of the rollout (rollout[0]), i.e. a tensordict without a time dimension + print(policy(rollout[0])) diff --git a/torchrl/modules/tensordict_module/rnn.py b/torchrl/modules/tensordict_module/rnn.py index 68309c346cd..07bf0337c4e 100644 --- a/torchrl/modules/tensordict_module/rnn.py +++ b/torchrl/modules/tensordict_module/rnn.py @@ -1652,8 +1652,8 @@ class set_recurrent_mode(_DecoratorContextManager): """Context manager for setting RNNs recurrent mode. Args: - mode (bool, "recurrent" or "stateful"): the recurrent mode to be used within the context manager.
- `"recurrent"` leads to `mode=True` and `"stateful"` leads to `mode=False`. + mode (bool, "recurrent" or "sequential"): the recurrent mode to be used within the context manager. + `"recurrent"` leads to `mode=True` and `"sequential"` leads to `mode=False`. An RNN executed with recurrent_mode "on" assumes that the data comes in time batches, otherwise it is assumed that each data element in a tensordict is independent of the others. The default value of this context manager is ``True``.