Skip to content

Commit

Permalink
[Feature] Make benchmarked losses compatible with torch.compile
Browse files Browse the repository at this point in the history
ghstack-source-id: 699a6bb6e4cb8982d09bf9aae447659e55e7f45e
Pull Request resolved: #2405
  • Loading branch information
vmoens committed Aug 30, 2024
1 parent e82a69f commit 94d44e5
Show file tree
Hide file tree
Showing 10 changed files with 465 additions and 142 deletions.
188 changes: 172 additions & 16 deletions benchmarks/test_objectives_benchmarks.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@

from tensordict import TensorDict
from tensordict.nn import (
InteractionType,
NormalParamExtractor,
ProbabilisticTensorDictModule as ProbMod,
ProbabilisticTensorDictSequential as ProbSeq,
Expand Down Expand Up @@ -137,7 +138,10 @@ def test_gae_speed(benchmark, gae_fn, gamma_tensor, batches, timesteps):
)


def test_dqn_speed(benchmark, n_obs=8, n_act=4, depth=3, ncells=128, batch=128):
@pytest.mark.parametrize("compile", [False, True, "reduce-overhead"])
def test_dqn_speed(
benchmark, compile, n_obs=8, n_act=4, depth=3, ncells=128, batch=128
):
net = MLP(in_features=n_obs, out_features=n_act, depth=depth, num_cells=ncells)
action_space = "one-hot"
mod = QValueActor(net, in_keys=["obs"], action_space=action_space)
Expand All @@ -155,10 +159,23 @@ def test_dqn_speed(benchmark, n_obs=8, n_act=4, depth=3, ncells=128, batch=128):
[batch],
)
loss(td)

if compile:
if isinstance(compile, str):
loss = torch.compile(loss, mode=compile, fullgraph=True)
else:
loss = torch.compile(loss, fullgraph=True)

loss(td)
loss(td)

benchmark(loss, td)


def test_ddpg_speed(benchmark, n_obs=8, n_act=4, ncells=128, batch=128, n_hidden=64):
@pytest.mark.parametrize("compile", [False, True, "reduce-overhead"])
def test_ddpg_speed(
benchmark, compile, n_obs=8, n_act=4, ncells=128, batch=128, n_hidden=64
):
common = MLP(
num_cells=ncells,
in_features=n_obs,
Expand Down Expand Up @@ -200,10 +217,23 @@ def test_ddpg_speed(benchmark, n_obs=8, n_act=4, ncells=128, batch=128, n_hidden
loss = DDPGLoss(actor, value)

loss(td)

if compile:
if isinstance(compile, str):
loss = torch.compile(loss, mode=compile, fullgraph=True)
else:
loss = torch.compile(loss, fullgraph=True)

loss(td)
loss(td)

benchmark(loss, td)


def test_sac_speed(benchmark, n_obs=8, n_act=4, ncells=128, batch=128, n_hidden=64):
@pytest.mark.parametrize("compile", [False, True, "reduce-overhead"])
def test_sac_speed(
benchmark, compile, n_obs=8, n_act=4, ncells=128, batch=128, n_hidden=64
):
common = MLP(
num_cells=ncells,
in_features=n_obs,
Expand Down Expand Up @@ -245,6 +275,7 @@ def test_sac_speed(benchmark, n_obs=8, n_act=4, ncells=128, batch=128, n_hidden=
in_keys=["loc", "scale"],
out_keys=["action"],
distribution_class=TanhNormal,
distribution_kwargs={"safe_tanh": False},
),
)
value_head = Mod(
Expand All @@ -256,10 +287,23 @@ def test_sac_speed(benchmark, n_obs=8, n_act=4, ncells=128, batch=128, n_hidden=
loss = SACLoss(actor, value, action_spec=Unbounded(shape=(n_act,)))

loss(td)

if compile:
if isinstance(compile, str):
loss = torch.compile(loss, mode=compile, fullgraph=True)
else:
loss = torch.compile(loss, fullgraph=True)

loss(td)
loss(td)

benchmark(loss, td)


def test_redq_speed(benchmark, n_obs=8, n_act=4, ncells=128, batch=128, n_hidden=64):
@pytest.mark.parametrize("compile", [False, True, "reduce-overhead"])
def test_redq_speed(
benchmark, compile, n_obs=8, n_act=4, ncells=128, batch=128, n_hidden=64
):
common = MLP(
num_cells=ncells,
in_features=n_obs,
Expand Down Expand Up @@ -313,11 +357,22 @@ def test_redq_speed(benchmark, n_obs=8, n_act=4, ncells=128, batch=128, n_hidden
loss = REDQLoss(actor, value, action_spec=Unbounded(shape=(n_act,)))

loss(td)

if compile:
if isinstance(compile, str):
loss = torch.compile(loss, mode=compile, fullgraph=True)
else:
loss = torch.compile(loss, fullgraph=True)

loss(td)
loss(td)

benchmark(loss, td)


@pytest.mark.parametrize("compile", [False, True, "reduce-overhead"])
def test_redq_deprec_speed(
benchmark, n_obs=8, n_act=4, ncells=128, batch=128, n_hidden=64
benchmark, compile, n_obs=8, n_act=4, ncells=128, batch=128, n_hidden=64
):
common = MLP(
num_cells=ncells,
Expand Down Expand Up @@ -372,10 +427,23 @@ def test_redq_deprec_speed(
loss = REDQLoss_deprecated(actor, value, action_spec=Unbounded(shape=(n_act,)))

loss(td)

if compile:
if isinstance(compile, str):
loss = torch.compile(loss, mode=compile, fullgraph=True)
else:
loss = torch.compile(loss, fullgraph=True)

loss(td)
loss(td)

benchmark(loss, td)


def test_td3_speed(benchmark, n_obs=8, n_act=4, ncells=128, batch=128, n_hidden=64):
@pytest.mark.parametrize("compile", [False, True, "reduce-overhead"])
def test_td3_speed(
benchmark, compile, n_obs=8, n_act=4, ncells=128, batch=128, n_hidden=64
):
common = MLP(
num_cells=ncells,
in_features=n_obs,
Expand Down Expand Up @@ -417,14 +485,23 @@ def test_td3_speed(benchmark, n_obs=8, n_act=4, ncells=128, batch=128, n_hidden=
in_keys=["loc", "scale"],
out_keys=["action"],
distribution_class=TanhNormal,
distribution_kwargs={"safe_tanh": False},
return_log_prob=True,
default_interaction_type=InteractionType.DETERMINISTIC,
),
)
value_head = Mod(
value, in_keys=["hidden", "action"], out_keys=["state_action_value"]
)
value = Seq(common, value_head)
value(actor(td))
value(actor(td.clone()))
if compile:
actor_c = torch.compile(actor.get_dist, fullgraph=True)
actor_c(td)
actor_c = torch.compile(actor, fullgraph=True)
actor_c(td)
value_c = torch.compile(value, fullgraph=True)
value_c(td)

loss = TD3Loss(
actor,
Expand All @@ -433,10 +510,23 @@ def test_td3_speed(benchmark, n_obs=8, n_act=4, ncells=128, batch=128, n_hidden=
)

loss(td)

if compile:
if isinstance(compile, str):
loss = torch.compile(loss, mode=compile, fullgraph=True)
else:
loss = torch.compile(loss, fullgraph=True)

loss(td)
loss(td)

benchmark.pedantic(loss, args=(td,), rounds=100, iterations=10)


def test_cql_speed(benchmark, n_obs=8, n_act=4, ncells=128, batch=128, n_hidden=64):
@pytest.mark.parametrize("compile", [False, True, "reduce-overhead"])
def test_cql_speed(
benchmark, compile, n_obs=8, n_act=4, ncells=128, batch=128, n_hidden=64
):
common = MLP(
num_cells=ncells,
in_features=n_obs,
Expand Down Expand Up @@ -475,7 +565,10 @@ def test_cql_speed(benchmark, n_obs=8, n_act=4, ncells=128, batch=128, n_hidden=
Mod(actor_net, in_keys=["hidden"], out_keys=["param"]),
Mod(NormalParamExtractor(), in_keys=["param"], out_keys=["loc", "scale"]),
ProbMod(
in_keys=["loc", "scale"], out_keys=["action"], distribution_class=TanhNormal
in_keys=["loc", "scale"],
out_keys=["action"],
distribution_class=TanhNormal,
distribution_kwargs={"safe_tanh": False},
),
)
value_head = Mod(
Expand All @@ -487,11 +580,22 @@ def test_cql_speed(benchmark, n_obs=8, n_act=4, ncells=128, batch=128, n_hidden=
loss = CQLLoss(actor, value, action_spec=Unbounded(shape=(n_act,)))

loss(td)

if compile:
if isinstance(compile, str):
loss = torch.compile(loss, mode=compile, fullgraph=True)
else:
loss = torch.compile(loss, fullgraph=True)

loss(td)
loss(td)

benchmark(loss, td)


@pytest.mark.parametrize("compile", [False, True, "reduce-overhead"])
def test_a2c_speed(
benchmark, n_obs=8, n_act=4, n_hidden=64, ncells=128, batch=128, T=10
benchmark, compile, n_obs=8, n_act=4, n_hidden=64, ncells=128, batch=128, T=10
):
common_net = MLP(
num_cells=ncells,
Expand Down Expand Up @@ -533,7 +637,10 @@ def test_a2c_speed(
Mod(actor_net, in_keys=["hidden"], out_keys=["param"]),
Mod(NormalParamExtractor(), in_keys=["param"], out_keys=["loc", "scale"]),
ProbMod(
in_keys=["loc", "scale"], out_keys=["action"], distribution_class=TanhNormal
in_keys=["loc", "scale"],
out_keys=["action"],
distribution_class=TanhNormal,
distribution_kwargs={"safe_tanh": False},
),
)
critic = Seq(common, Mod(value_net, in_keys=["hidden"], out_keys=["state_value"]))
Expand All @@ -544,11 +651,22 @@ def test_a2c_speed(
advantage = GAE(value_network=critic, gamma=0.99, lmbda=0.95, shifted=True)
advantage(td)
loss(td)

if compile:
if isinstance(compile, str):
loss = torch.compile(loss, mode=compile, fullgraph=True)
else:
loss = torch.compile(loss, fullgraph=True)

loss(td)
loss(td)

benchmark(loss, td)


@pytest.mark.parametrize("compile", [False, True, "reduce-overhead"])
def test_ppo_speed(
benchmark, n_obs=8, n_act=4, n_hidden=64, ncells=128, batch=128, T=10
benchmark, compile, n_obs=8, n_act=4, n_hidden=64, ncells=128, batch=128, T=10
):
common_net = MLP(
num_cells=ncells,
Expand Down Expand Up @@ -590,7 +708,10 @@ def test_ppo_speed(
Mod(actor_net, in_keys=["hidden"], out_keys=["param"]),
Mod(NormalParamExtractor(), in_keys=["param"], out_keys=["loc", "scale"]),
ProbMod(
in_keys=["loc", "scale"], out_keys=["action"], distribution_class=TanhNormal
in_keys=["loc", "scale"],
out_keys=["action"],
distribution_class=TanhNormal,
distribution_kwargs={"safe_tanh": False},
),
)
critic = Seq(common, Mod(value_net, in_keys=["hidden"], out_keys=["state_value"]))
Expand All @@ -601,11 +722,22 @@ def test_ppo_speed(
advantage = GAE(value_network=critic, gamma=0.99, lmbda=0.95, shifted=True)
advantage(td)
loss(td)

if compile:
if isinstance(compile, str):
loss = torch.compile(loss, mode=compile, fullgraph=True)
else:
loss = torch.compile(loss, fullgraph=True)

loss(td)
loss(td)

benchmark(loss, td)


@pytest.mark.parametrize("compile", [False, True, "reduce-overhead"])
def test_reinforce_speed(
benchmark, n_obs=8, n_act=4, n_hidden=64, ncells=128, batch=128, T=10
benchmark, compile, n_obs=8, n_act=4, n_hidden=64, ncells=128, batch=128, T=10
):
common_net = MLP(
num_cells=ncells,
Expand Down Expand Up @@ -647,7 +779,10 @@ def test_reinforce_speed(
Mod(actor_net, in_keys=["hidden"], out_keys=["param"]),
Mod(NormalParamExtractor(), in_keys=["param"], out_keys=["loc", "scale"]),
ProbMod(
in_keys=["loc", "scale"], out_keys=["action"], distribution_class=TanhNormal
in_keys=["loc", "scale"],
out_keys=["action"],
distribution_class=TanhNormal,
distribution_kwargs={"safe_tanh": False},
),
)
critic = Seq(common, Mod(value_net, in_keys=["hidden"], out_keys=["state_value"]))
Expand All @@ -658,11 +793,22 @@ def test_reinforce_speed(
advantage = GAE(value_network=critic, gamma=0.99, lmbda=0.95, shifted=True)
advantage(td)
loss(td)

if compile:
if isinstance(compile, str):
loss = torch.compile(loss, mode=compile, fullgraph=True)
else:
loss = torch.compile(loss, fullgraph=True)

loss(td)
loss(td)

benchmark(loss, td)


@pytest.mark.parametrize("compile", [False, True, "reduce-overhead"])
def test_iql_speed(
benchmark, n_obs=8, n_act=4, n_hidden=64, ncells=128, batch=128, T=10
benchmark, compile, n_obs=8, n_act=4, n_hidden=64, ncells=128, batch=128, T=10
):
common_net = MLP(
num_cells=ncells,
Expand Down Expand Up @@ -723,6 +869,16 @@ def test_iql_speed(

loss = IQLLoss(actor_network=actor, value_network=value, qvalue_network=qvalue)
loss(td)

if compile:
if isinstance(compile, str):
loss = torch.compile(loss, mode=compile, fullgraph=True)
else:
loss = torch.compile(loss, fullgraph=True)

loss(td)
loss(td)

benchmark(loss, td)


Expand Down
8 changes: 6 additions & 2 deletions torchrl/envs/transforms/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,9 +39,13 @@
unravel_key,
unravel_key_list,
)
from tensordict._C import _unravel_key_to_tuple
from tensordict.nn import dispatch, TensorDictModuleBase
from tensordict.utils import expand_as_right, expand_right, NestedKey
from tensordict.utils import (
_unravel_key_to_tuple,
expand_as_right,
expand_right,
NestedKey,
)
from torch import nn, Tensor
from torch.utils._pytree import tree_map

Expand Down
Loading

0 comments on commit 94d44e5

Please sign in to comment.