Create optimizer in OnPolicyAlgorithm only after the device is set #1771

Closed · wants to merge 1 commit

7 changes: 7 additions & 0 deletions docs/misc/changelog.rst
@@ -3,6 +3,13 @@
Changelog
==========

Unreleased
----------

Bug Fixes:
^^^^^^^^^^
- In ``OnPolicyAlgorithm``, create the optimizer only after the policy has been moved to its target device, so that device-dependent optimizer options (e.g. fused Adam) work (@cmangla)

Release 2.2.1 (2023-11-17)
--------------------------
**Support for options at reset, bug fixes and better error messages**
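Note on the bug being fixed: some optimizer options are only valid once the model parameters are on their final device. A minimal sketch of the failure mode, assuming a PyTorch version where fused Adam is CUDA-only (the exact error message varies by version):

```python
import torch as th
import torch.nn as nn

policy = nn.Linear(4, 2)  # parameters are created on the CPU

# On PyTorch builds where fused Adam is CUDA-only, constructing the optimizer
# before the module has been moved to the GPU fails at construction time.
try:
    th.optim.Adam(policy.parameters(), lr=3e-4, fused=True)
except RuntimeError as exc:
    print(f"optimizer creation failed on CPU parameters: {exc}")

# Moving the module first and only then creating the optimizer works.
if th.cuda.is_available():
    policy = policy.to("cuda")
    optimizer = th.optim.Adam(policy.parameters(), lr=3e-4, fused=True)
```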
8 changes: 7 additions & 1 deletion stable_baselines3/common/on_policy_algorithm.py
@@ -131,9 +131,15 @@ def _setup_model(self) -> None:
**self.rollout_buffer_kwargs,
)
self.policy = self.policy_class( # type: ignore[assignment]
self.observation_space, self.action_space, self.lr_schedule, use_sde=self.use_sde, **self.policy_kwargs
self.observation_space,
self.action_space,
self.lr_schedule,
use_sde=self.use_sde,
_init_optimizer=False,
**self.policy_kwargs,
)
self.policy = self.policy.to(self.device)
self.policy._build_optimizer(self.lr_schedule)

def collect_rollouts(
self,
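In short, the hunk above changes the setup order to: build the policy, move it to ``self.device``, then create the optimizer. A standalone sketch of that ordering (plain PyTorch, not SB3 code):

```python
import torch as th
import torch.nn as nn

device = "cuda" if th.cuda.is_available() else "cpu"

# 1. Build the network; its parameters start on the CPU.
policy = nn.Sequential(nn.Linear(8, 64), nn.Tanh(), nn.Linear(64, 2))

# 2. Move it to the target device before the optimizer exists.
policy = policy.to(device)

# 3. Only now create the optimizer: device-sensitive options (such as
#    fused Adam) are validated against the parameters' device at this point.
optimizer = th.optim.Adam(policy.parameters(), lr=3e-4)
```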
23 changes: 20 additions & 3 deletions stable_baselines3/common/policies.py
@@ -441,6 +441,8 @@ class ActorCriticPolicy(BasePolicy):
``th.optim.Adam`` by default
:param optimizer_kwargs: Additional keyword arguments,
excluding the learning rate, to pass to the optimizer
:param _init_optimizer: Whether to initialize the optimizer immediately (True by default).
    When set to False, the optimizer must be created later by calling ``_build_optimizer(...)``.
"""

def __init__(
@@ -462,6 +464,7 @@ def __init__(
normalize_images: bool = True,
optimizer_class: Type[th.optim.Optimizer] = th.optim.Adam,
optimizer_kwargs: Optional[Dict[str, Any]] = None,
_init_optimizer: bool = True,
):
if optimizer_kwargs is None:
optimizer_kwargs = {}
@@ -530,7 +533,7 @@ def __init__(
# Action distribution
self.action_dist = make_proba_distribution(action_space, use_sde=use_sde, dist_kwargs=dist_kwargs)

self._build(lr_schedule)
self._build(lr_schedule, _init_optimizer)

def _get_constructor_parameters(self) -> Dict[str, Any]:
data = super()._get_constructor_parameters()
@@ -580,7 +583,7 @@ def _build_mlp_extractor(self) -> None:
device=self.device,
)

def _build(self, lr_schedule: Schedule) -> None:
def _build(self, lr_schedule: Schedule, _init_optimizer: bool = True) -> None:
"""
Create the networks and the optimizer.

@@ -628,8 +631,14 @@ def _build(self, lr_schedule: Schedule) -> None:
for module, gain in module_gains.items():
module.apply(partial(self.init_weights, gain=gain))

if _init_optimizer:
self._build_optimizer(lr_schedule)

def _build_optimizer(self, lr_schedule: Schedule) -> None:
# Setup optimizer with initial learning rate
self.optimizer = self.optimizer_class(self.parameters(), lr=lr_schedule(1), **self.optimizer_kwargs) # type: ignore[call-arg]
self.optimizer = self.optimizer_class(  # type: ignore[call-arg]
self.parameters(), lr=lr_schedule(1), **self.optimizer_kwargs
)

def forward(self, obs: th.Tensor, deterministic: bool = False) -> Tuple[th.Tensor, th.Tensor, th.Tensor]:
"""
@@ -791,6 +800,8 @@ class ActorCriticCnnPolicy(ActorCriticPolicy):
``th.optim.Adam`` by default
:param optimizer_kwargs: Additional keyword arguments,
excluding the learning rate, to pass to the optimizer
:param _init_optimizer: Whether to initialize the optimizer immediately (True by default).
    When set to False, the optimizer must be created later by calling ``_build_optimizer(...)``.
"""

def __init__(
@@ -812,6 +823,7 @@ def __init__(
normalize_images: bool = True,
optimizer_class: Type[th.optim.Optimizer] = th.optim.Adam,
optimizer_kwargs: Optional[Dict[str, Any]] = None,
_init_optimizer: bool = True,
):
super().__init__(
observation_space,
@@ -831,6 +843,7 @@ def __init__(
normalize_images,
optimizer_class,
optimizer_kwargs,
_init_optimizer,
)


@@ -864,6 +877,8 @@ class MultiInputActorCriticPolicy(ActorCriticPolicy):
``th.optim.Adam`` by default
:param optimizer_kwargs: Additional keyword arguments,
excluding the learning rate, to pass to the optimizer
:param _init_optimizer: Whether to initialize the optimizer immediately (True by default).
    When set to False, the optimizer must be created later by calling ``_build_optimizer(...)``.
"""

def __init__(
@@ -885,6 +900,7 @@ def __init__(
normalize_images: bool = True,
optimizer_class: Type[th.optim.Optimizer] = th.optim.Adam,
optimizer_kwargs: Optional[Dict[str, Any]] = None,
_init_optimizer: bool = True,
):
super().__init__(
observation_space,
@@ -904,6 +920,7 @@ def __init__(
normalize_images,
optimizer_class,
optimizer_kwargs,
_init_optimizer,
)


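The keyword argument and helper added above are private and specific to this commit. A hedged sketch of using them directly on ``ActorCriticPolicy``, mirroring what ``OnPolicyAlgorithm._setup_model`` now does (``Pendulum-v1`` and the constant schedule are just placeholders):

```python
import gymnasium as gym
import torch as th
from stable_baselines3.common.policies import ActorCriticPolicy

env = gym.make("Pendulum-v1")
use_cuda = th.cuda.is_available()

def lr_schedule(progress_remaining: float) -> float:
    return 3e-4  # constant learning rate

# Skip optimizer creation in the constructor (new in this PR) ...
policy = ActorCriticPolicy(
    env.observation_space,
    env.action_space,
    lr_schedule,
    # fused Adam is the device-sensitive option motivating the change;
    # only request it when a GPU is actually available.
    optimizer_kwargs={"fused": True} if use_cuda else None,
    _init_optimizer=False,
)

# ... move the parameters to their final device ...
policy = policy.to("cuda" if use_cuda else "cpu")

# ... and only then build the optimizer against the moved parameters.
policy._build_optimizer(lr_schedule)
```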
9 changes: 7 additions & 2 deletions tests/test_predict.py
@@ -59,18 +59,23 @@ def test_auto_wrap(model_class):
@pytest.mark.parametrize("model_class", MODEL_LIST)
@pytest.mark.parametrize("env_id", ["Pendulum-v1", "CartPole-v1"])
@pytest.mark.parametrize("device", ["cpu", "cuda", "auto"])
def test_predict(model_class, env_id, device):
@pytest.mark.parametrize("policy_kwargs", [None, dict(optimizer_kwargs={"fused": True})])
def test_predict(model_class, env_id, device, policy_kwargs):
if device == "cuda" and not th.cuda.is_available():
pytest.skip("CUDA not available")

if policy_kwargs is not None:
if not (device == "cuda" and model_class in [PPO, A2C]):
pytest.skip("Fused Adam requires CUDA and is only tested with PPO/A2C")

if env_id == "CartPole-v1":
if model_class in [SAC, TD3]:
return
elif model_class in [DQN]:
return

# Test detection of different shapes by the predict method
model = model_class("MlpPolicy", env_id, device=device)
model = model_class("MlpPolicy", env_id, device=device, policy_kwargs=policy_kwargs)
# Check that the policy is on the right device
assert get_device(device).type == model.policy.device.type

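For reference, the scenario the new parametrization exercises, written as a standalone sketch (CUDA is required for ``fused=True``, hence the guard):

```python
import torch as th
from stable_baselines3 import PPO
from stable_baselines3.common.utils import get_device

if th.cuda.is_available():
    # optimizer_kwargs are forwarded via policy_kwargs; with this change the
    # optimizer is created only after the policy is on the GPU, so the
    # CUDA-only fused Adam option is accepted.
    model = PPO(
        "MlpPolicy",
        "CartPole-v1",
        device="cuda",
        policy_kwargs=dict(optimizer_kwargs={"fused": True}),
    )
    # Same check as in the test: the policy ended up on the requested device.
    assert get_device("cuda").type == model.policy.device.type
```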