Skip to content

Commit

Permalink
[Refactor] Limit the deepcopies in collectors
Browse files Browse the repository at this point in the history
ghstack-source-id: 876431a03550b9fe933dc4c53a3f949cdb3abd1c
Pull Request resolved: #2451
  • Loading branch information
vmoens committed Oct 1, 2024
1 parent 5851652 commit 1858bea
Show file tree
Hide file tree
Showing 17 changed files with 516 additions and 209 deletions.
2 changes: 1 addition & 1 deletion .github/unittest/linux/scripts/run_all.sh
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ export DISPLAY=:0
export SDL_VIDEODRIVER=dummy

# legacy from bash scripts: remove?
conda env config vars set MUJOCO_GL=$MUJOCO_GL PYOPENGL_PLATFORM=$MUJOCO_GL DISPLAY=:0 SDL_VIDEODRIVER=dummy LAZY_LEGACY_OP=False
conda env config vars set MUJOCO_GL=$MUJOCO_GL PYOPENGL_PLATFORM=$MUJOCO_GL DISPLAY=:0 SDL_VIDEODRIVER=dummy LAZY_LEGACY_OP=False RL_LOGGING_LEVEL=DEBUG

pip3 install pip --upgrade
pip install virtualenv
Expand Down
57 changes: 56 additions & 1 deletion benchmarks/test_collectors_benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,16 +3,20 @@
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import argparse
import time

import pytest
import torch.cuda
import tqdm

from torchrl.collectors import SyncDataCollector
from torchrl.collectors.collectors import (
MultiaSyncDataCollector,
MultiSyncDataCollector,
)
from torchrl.envs import EnvCreator, GymEnv, StepCounter, TransformedEnv
from torchrl.data import LazyTensorStorage, ReplayBuffer
from torchrl.data.utils import CloudpickleWrapper
from torchrl.envs import EnvCreator, GymEnv, ParallelEnv, StepCounter, TransformedEnv
from torchrl.envs.libs.dm_control import DMControlEnv
from torchrl.envs.utils import RandomPolicy

Expand Down Expand Up @@ -180,6 +184,57 @@ def test_async_pixels(benchmark):
benchmark(execute_collector, c)


class TestRBGCollector:
    """Throughput benchmarks for collectors that write directly into a replay buffer."""

    @pytest.mark.parametrize(
        "n_col,n_workers_per_col",
        [
            [2, 2],
            [4, 2],
            [8, 2],
            [16, 2],
            [2, 1],
            [4, 1],
            [8, 1],
            [16, 1],
        ],
    )
    def test_multiasync_rb(self, n_col, n_workers_per_col):
        """Measure the FPS of a :class:`MultiaSyncDataCollector` extending a replay buffer.

        Args:
            n_col: number of sub-collectors.
            n_workers_per_col: number of env workers per sub-collector; when >1 the
                envs are batched in a :class:`ParallelEnv`.
        """
        make_env = EnvCreator(lambda: GymEnv("ALE/Pong-v5"))
        if n_workers_per_col > 1:
            make_env = ParallelEnv(n_workers_per_col, make_env)
            env = make_env
            policy = RandomPolicy(env.action_spec)
        else:
            env = make_env()
            policy = RandomPolicy(env.action_spec)

        storage = LazyTensorStorage(10_000)
        rb = ReplayBuffer(storage=storage)
        # Seed the buffer so its specs/shapes are initialized before collection.
        rb.extend(env.rollout(2, policy).reshape(-1))
        # Flatten each incoming batch before it is written to the buffer.
        rb.append_transform(CloudpickleWrapper(lambda x: x.reshape(-1)), invert=True)

        fpb = n_workers_per_col * 100
        total_frames = n_workers_per_col * 100_000
        c = MultiaSyncDataCollector(
            [make_env] * n_col,
            policy,
            frames_per_batch=fpb,
            total_frames=total_frames,
            replay_buffer=rb,
        )
        frames = 0
        pbar = tqdm.tqdm(total=total_frames - (n_col * fpb))
        try:
            for i, _ in enumerate(c):
                # The first n_col batches are treated as warm-up and excluded
                # from the timing window.
                if i == n_col:
                    t0 = time.time()
                if i >= n_col:
                    frames += fpb
                if i > n_col:
                    fps = frames / (time.time() - t0)
                    pbar.update(fpb)
                    pbar.set_description(f"fps: {fps: 4.4f}")
        finally:
            # Release worker processes and the progress bar even if the
            # benchmark loop raises.
            pbar.close()
            c.shutdown()


if __name__ == "__main__":
    # Allow running this benchmark file directly; any extra CLI arguments are
    # forwarded to pytest unchanged.
    args, unknown = argparse.ArgumentParser().parse_known_args()
    pytest.main([__file__, "--capture", "no", "--exitfirst"] + unknown)
Binary file added docs/source/_static/img/collector-copy.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
19 changes: 19 additions & 0 deletions docs/source/reference/collectors.rst
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,25 @@ delivers batches of data on a first-come, first-serve basis, whereas
:class:`~torchrl.collectors.MultiSyncDataCollector` gathers data from
each sub-collector before delivering it.

Collectors and policy copies
----------------------------

When passing a policy to a collector, we can choose the device on which this policy will be run. This can be used to
keep the training version of the policy on a device and the inference version on another. For example, if you have two
CUDA devices, it may be wise to train on one device and execute the policy for inference on the other. If that is the
case, the :meth:`~torchrl.collectors.DataCollector.update_policy_weights_` method can be used to copy the parameters from one
device to the other (if no copy is required, this method is a no-op).

Since the goal is to avoid calling `policy.to(policy_device)` explicitly, the collector will do a deepcopy of the
policy structure and copy the parameters placed on the new device during instantiation if necessary.
Since not all policies support deepcopies (e.g., policies using CUDA graphs or relying on third-party libraries), we
try to limit the cases where a deepcopy will be executed. The following chart shows when this will occur.

.. figure:: /_static/img/collector-copy.png

Policy copy decision tree in Collectors.


Collectors and replay buffers interoperability
----------------------------------------------

Expand Down
2 changes: 2 additions & 0 deletions test/_utils_internal.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,6 +155,8 @@ def get_default_devices():
return [torch.device("cpu")]
elif num_cuda == 1:
return [torch.device("cuda:0")]
elif torch.mps.is_available():
return [torch.device("mps:0")]
else:
# then run on all devices
return get_available_devices()
Expand Down
164 changes: 153 additions & 11 deletions test/test_collector.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,12 @@
TensorDict,
TensorDictBase,
)
from tensordict.nn import TensorDictModule, TensorDictModuleBase, TensorDictSequential
from tensordict.nn import (
CudaGraphModule,
TensorDictModule,
TensorDictModuleBase,
TensorDictSequential,
)

from torch import nn
from torchrl._utils import (
Expand All @@ -76,6 +81,7 @@
TensorSpec,
Unbounded,
)
from torchrl.data.utils import CloudpickleWrapper
from torchrl.envs import (
EnvBase,
EnvCreator,
Expand Down Expand Up @@ -1597,8 +1603,8 @@ def test_auto_wrap_error(self, collector_class, env_maker):
policy = UnwrappablePolicy(out_features=env_maker().action_spec.shape[-1])
with pytest.raises(
TypeError,
match=(r"Arguments to policy.forward are incompatible with entries in"),
) if collector_class is SyncDataCollector else pytest.raises(EOFError):
match=("Arguments to policy.forward are incompatible with entries in"),
):
collector_class(
**self._create_collector_kwargs(env_maker, collector_class, policy)
)
Expand Down Expand Up @@ -1827,10 +1833,15 @@ def test_set_truncated(collector_cls):
NestedCountingEnv(), InitTracker()
).add_truncated_keys()
env = env_fn()
policy = env.rand_action
policy = CloudpickleWrapper(env.rand_action)
if collector_cls == SyncDataCollector:
collector = collector_cls(
env, policy=policy, frames_per_batch=20, total_frames=-1, set_truncated=True
env,
policy=policy,
frames_per_batch=20,
total_frames=-1,
set_truncated=True,
trust_policy=True,
)
else:
collector = collector_cls(
Expand All @@ -1840,6 +1851,7 @@ def test_set_truncated(collector_cls):
total_frames=-1,
cat_results="stack",
set_truncated=True,
trust_policy=True,
)
try:
for data in collector:
Expand Down Expand Up @@ -2147,7 +2159,10 @@ def test_multi_collector_consistency(
assert_allclose_td(c2.unsqueeze(0), d2)


@pytest.mark.skipif(not torch.cuda.device_count(), reason="No casting if no cuda")
@pytest.mark.skipif(
not torch.cuda.is_available() and not torch.mps.is_available(),
reason="No casting if no cuda",
)
class TestUpdateParams:
class DummyEnv(EnvBase):
def __init__(self, device, batch_size=[]): # noqa: B006
Expand Down Expand Up @@ -2205,8 +2220,8 @@ def forward(self, td):
@pytest.mark.parametrize(
"policy_device,env_device",
[
["cpu", "cuda"],
["cuda", "cpu"],
["cpu", get_default_devices()[0]],
[get_default_devices()[0], "cpu"],
# ["cpu", "cuda:0"], # 1226: faster execution
# ["cuda:0", "cpu"],
# ["cuda", "cuda:0"],
Expand All @@ -2230,9 +2245,7 @@ def test_param_sync(self, give_weights, collector, policy_device, env_device):
policy.param.data += 1
policy.buf.data += 2
if give_weights:
d = dict(policy.named_parameters())
d.update(policy.named_buffers())
p_w = TensorDict(d, [])
p_w = TensorDict.from_module(policy)
else:
p_w = None
col.update_policy_weights_(p_w)
Expand Down Expand Up @@ -2909,6 +2922,135 @@ def test_collector_rb_multiasync(
assert (idsdiff >= 0).all()


def __deepcopy_error__(*args, **kwargs):
    """Substitute ``__deepcopy__`` hook that unconditionally fails.

    Patched onto a policy so the test can prove that no deepcopy of the
    policy ever takes place during collector instantiation.
    """
    message = "deepcopy not allowed"
    raise RuntimeError(message)


# NOTE: warnings are promoted to errors, so any UserWarning not explicitly
# expected below fails the test.
@pytest.mark.filterwarnings("error")
@pytest.mark.parametrize(
    "collector_type",
    [
        SyncDataCollector,
        MultiaSyncDataCollector,
        functools.partial(MultiSyncDataCollector, cat_results="stack"),
    ],
)
def test_no_deepcopy_policy(collector_type):
    # Tests that the collector instantiation does not make a deepcopy of the policy if not necessary.
    #
    # The only situation where we want to deepcopy the policy is when the policy_device differs from the actual device
    # of the policy. This can only be checked if the policy is an nn.Module and any of the params is not on the desired
    # device.
    #
    # If the policy is not a nn.Module or has no parameter, policy_device should warn (we don't know what to do but we
    # can trust that the user knows what to do).

    # `shared_device` is where a copy would land; `original_device` is where the
    # policy params actually live (requires an accelerator, otherwise skip).
    shared_device = torch.device("cpu")
    if torch.cuda.is_available():
        original_device = torch.device("cuda:0")
    elif torch.mps.is_available():
        original_device = torch.device("mps")
    else:
        pytest.skip("No GPU or MPS device")

    def make_policy(device=None, nn_module=True):
        # Returns a TensorDictModule policy, or (nn_module=False) the same
        # policy hidden behind a CloudpickleWrapper so it no longer looks like
        # an nn.Module to the collector.
        if nn_module:
            return TensorDictModule(
                nn.Linear(7, 7, device=device),
                in_keys=["observation"],
                out_keys=["action"],
            )
        policy = make_policy(device=device)
        return CloudpickleWrapper(policy)

    def make_and_test_policy(
        policy,
        policy_device=None,
        env_device=None,
        device=None,
        trust_policy=None,
    ):
        # make sure policy errors when copied

        # Patch __deepcopy__ so any attempted deepcopy raises RuntimeError.
        policy.__deepcopy__ = __deepcopy_error__
        envs = ContinuousActionVecMockEnv(device=env_device)
        if collector_type is not SyncDataCollector:
            envs = [envs, envs]
        c = collector_type(
            envs,
            policy=policy,
            total_frames=1000,
            frames_per_batch=100,
            policy_device=policy_device,
            env_device=env_device,
            device=device,
            trust_policy=trust_policy,
        )
        # One iteration is enough to force full collector initialization.
        for _ in c:
            return

    # Simplest use cases
    policy = make_policy()
    make_and_test_policy(policy)

    if collector_type is SyncDataCollector or original_device.type != "mps":
        # mps cannot be shared
        policy = make_policy(device=original_device)
        make_and_test_policy(policy, env_device=original_device)

    if collector_type is SyncDataCollector or original_device.type != "mps":
        policy = make_policy(device=original_device)
        make_and_test_policy(
            policy, policy_device=original_device, env_device=original_device
        )

    # a deepcopy must occur when the policy_device differs from the actual device
    with pytest.raises(RuntimeError, match="deepcopy not allowed"):
        policy = make_policy(device=original_device)
        make_and_test_policy(
            policy, policy_device=shared_device, env_device=shared_device
        )

    # a deepcopy must occur when device differs from the actual device
    with pytest.raises(RuntimeError, match="deepcopy not allowed"):
        policy = make_policy(device=original_device)
        make_and_test_policy(policy, device=shared_device)

    # If the policy is not an nn.Module, we can't cast it to device, so we assume that the policy device
    # is there to inform us
    substitute_device = (
        original_device if torch.cuda.is_available() else torch.device("cpu")
    )
    policy = make_policy(substitute_device, nn_module=False)
    with pytest.warns(UserWarning):
        make_and_test_policy(
            policy, policy_device=substitute_device, env_device=substitute_device
        )
    # For instance, if the env is on CPU, knowing the policy device helps with casting stuff on the right device
    with pytest.warns(UserWarning):
        make_and_test_policy(
            policy, policy_device=substitute_device, env_device=shared_device
        )
    # trust_policy=True silences the warning: the user vouches for the policy.
    make_and_test_policy(
        policy,
        policy_device=substitute_device,
        env_device=shared_device,
        trust_policy=True,
    )

    # If there is no policy_device, we assume that the user is doing things right too but don't warn
    if collector_type is SyncDataCollector or original_device.type != "mps":
        policy = make_policy(original_device, nn_module=False)
        make_and_test_policy(policy, env_device=original_device)

    # If the policy is a CudaGraphModule, we know it's on cuda - no need to warn
    if torch.cuda.is_available():
        with pytest.warns(UserWarning, match="Tensordict is registered in PyTree"):
            policy = make_policy(original_device)
            cudagraph_policy = CudaGraphModule(policy)
            make_and_test_policy(cudagraph_policy, policy_device=original_device)


if __name__ == "__main__":
    # Allow running this test file directly; any extra CLI arguments are
    # forwarded to pytest unchanged.
    args, unknown = argparse.ArgumentParser().parse_known_args()
    pytest.main([__file__, "--capture", "no", "--exitfirst"] + unknown)
2 changes: 1 addition & 1 deletion test/test_distributed.py
Original file line number Diff line number Diff line change
Expand Up @@ -405,7 +405,7 @@ def _test_distributed_collector_updatepolicy(
MultiaSyncDataCollector,
],
)
@pytest.mark.parametrize("update_interval", [1_000_000, 1])
@pytest.mark.parametrize("update_interval", [1])
def test_distributed_collector_updatepolicy(self, collector_class, update_interval):
"""Testing various collector classes to be used in nodes."""
queue = mp.Queue(1)
Expand Down
4 changes: 3 additions & 1 deletion torchrl/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@
console_handler.setFormatter(formatter)
logger.addHandler(console_handler)

VERBOSE = strtobool(os.environ.get("VERBOSE", "0"))
VERBOSE = strtobool(os.environ.get("VERBOSE", str(logger.isEnabledFor(logging.DEBUG))))
_os_is_windows = sys.platform == "win32"
RL_WARNINGS = strtobool(os.environ.get("RL_WARNINGS", "1"))
if RL_WARNINGS:
Expand Down Expand Up @@ -785,4 +785,6 @@ def _make_ordinal_device(device: torch.device):
return device
if device.type == "cuda" and device.index is None:
return torch.device("cuda", index=torch.cuda.current_device())
if device.type == "mps" and device.index is None:
return torch.device("mps", index=0)
return device
Loading

0 comments on commit 1858bea

Please sign in to comment.