Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feature/param reset #328

Draft
wants to merge 12 commits into
base: master
Choose a base branch
from
3 changes: 2 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,11 @@ docs/d3rlpy*.rst
docs/modules.rst
docs/references/generated
coverage.xml
.coverage
.coverage*
.mypy_cache
.ipynb_checkpoints
build
dist
/.idea/
*.egg-info
*.DS_Store
73 changes: 73 additions & 0 deletions d3rlpy/algos/qlearning/torch/callbacks.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
from abc import ABCMeta, abstractmethod
from typing import List, Optional, Sequence

import torch.nn as nn

from ....constants import IMPL_NOT_INITIALIZED_ERROR
from ... import QLearningAlgoBase, QLearningAlgoImplBase

# Public API of this module. Only ParameterReset is exported; the
# QLearningCallback base class defined below is not listed in ``__all__``.
__all__ = [
    "ParameterReset"
]

class QLearningCallback(metaclass=ABCMeta):
    """Abstract callable invoked with an algorithm and its training progress.

    Concrete subclasses must implement ``__call__``; this base class cannot be
    instantiated directly because ``__call__`` is abstract.
    """

    @abstractmethod
    def __call__(
        self, algo: QLearningAlgoBase, epoch: int, total_step: int
    ) -> None:
        """Handle one callback invocation.

        Args:
            algo: Algorithm instance the callback operates on.
            epoch: Epoch number passed by the caller.
            total_step: Step count passed by the caller.
        """


class ParameterReset(QLearningCallback):
    """Callback that periodically re-initializes selected Q-function layers.

    Whenever ``epoch % replay_ratio == 0``, every Q-function found in
    ``algo._impl.q_function`` has ``reset_parameters()`` called on each layer
    whose flag is ``True``. Flags are matched positionally against the encoder
    layers followed by the output (``_fc``) layer.

    Args:
        replay_ratio: Reset period, in epochs.
        encoder_reset: One flag per encoder layer, in layer order; ``True``
            marks the layer for reset.
        output_reset: Whether the output (``_fc``) layer is reset.
        algo: Optional algorithm used to validate the flags immediately.
            When omitted, validation is deferred to the first call.

    Raises:
        ValueError: If ``algo`` is given and the flags do not match its
            Q-function layers (see ``_check_layer_resets``).
    """

    def __init__(
        self,
        replay_ratio: int,
        encoder_reset: Sequence[bool],
        output_reset: bool,
        algo: Optional[QLearningAlgoBase] = None,
    ) -> None:
        self._replay_ratio = replay_ratio
        self._encoder_reset = encoder_reset
        self._output_reset = output_reset
        # Becomes True once the flags were validated against an algorithm.
        self._check = False
        if algo is not None:
            self._check_layer_resets(algo=algo)

    def _reset_flags(self) -> List[bool]:
        """Return the per-layer flags: encoder flags, then the output flag."""
        return [*self._encoder_reset, self._output_reset]

    def _get_layers(self, q_func: nn.Module) -> List[nn.Module]:
        """Return the encoder layers of ``q_func`` followed by its output layer.

        The layers are looked up by module name: the members of
        ``_encoder._layers`` first, then ``_fc``.

        Raises:
            KeyError: If ``q_func`` has no module with one of those names.
        """
        # NOTE(review): open question raised by the PR author to @takuseno --
        # assuming the callback approach is acceptable, is there a better,
        # statically typed way to obtain the encoder and fc layers than
        # string lookups in ``named_modules()``?
        named = dict(q_func.named_modules())
        return [*named["_encoder._layers"], named["_fc"]]

    def _check_layer_resets(self, algo: QLearningAlgoBase) -> None:
        """Validate the reset flags against every Q-function of ``algo``.

        Sets ``self._check`` to ``True`` on success.

        Raises:
            ValueError: If the number of encoder flags does not match the
                number of encoder layers, or if any layer flagged for reset
                (including the output layer) lacks ``reset_parameters``.
        """
        assert algo._impl is not None, IMPL_NOT_INITIALIZED_ERROR
        assert isinstance(algo._impl, QLearningAlgoImplBase)

        # Validate against the same flag list that __call__ iterates, so the
        # output layer is checked too when ``output_reset`` is set.
        reset_flags = self._reset_flags()
        resettable = True
        for q_func in algo._impl.q_function:
            q_func_layers = self._get_layers(q_func)
            if len(reset_flags) != len(q_func_layers):
                raise ValueError(
                    "encoder_reset must contain exactly one flag per encoder "
                    f"layer; q_function layers: {q_func_layers}; "
                    f"specified encoder layers: {self._encoder_reset}"
                )
            resettable = resettable and all(
                hasattr(layer, "reset_parameters")
                for flag, layer in zip(reset_flags, q_func_layers)
                if flag
            )
        self._check = resettable
        if not self._check:
            raise ValueError(
                "Some layer do not contain resettable parameters"
            )

    def __call__(
        self, algo: QLearningAlgoBase, epoch: int, total_step: int
    ) -> None:
        """Reset the flagged layers when ``epoch`` is a multiple of the period.

        Args:
            algo: Algorithm whose Q-functions are reset.
            epoch: Current epoch; a reset happens when
                ``epoch % replay_ratio == 0``.
            total_step: Unused by this callback.

        Raises:
            ValueError: If the flags were not validated yet and do not match
                the layers of ``algo``.
        """
        if not self._check:
            self._check_layer_resets(algo=algo)
        assert isinstance(algo._impl, QLearningAlgoImplBase)
        if epoch % self._replay_ratio != 0:
            return
        reset_flags = self._reset_flags()
        for q_func in algo._impl.q_function:
            for flag, layer in zip(reset_flags, self._get_layers(q_func)):
                if flag:
                    layer.reset_parameters()
42 changes: 37 additions & 5 deletions d3rlpy/algos/qlearning/torch/cql_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import torch.nn.functional as F
from torch.optim import Optimizer

from ....dataclass_utils import asdict_as_float
from ....dataset import Shape
from ....models.torch import (
ContinuousEnsembleQFunctionForwarder,
Expand All @@ -14,10 +15,17 @@
build_squashed_gaussian_distribution,
)
from ....torch_utility import TorchMiniBatch
from .ddpg_impl import DDPGCriticLoss
from .dqn_impl import DoubleDQNImpl, DQNLoss, DQNModules
from .sac_impl import SACImpl, SACModules

__all__ = ["CQLImpl", "DiscreteCQLImpl", "CQLModules", "DiscreteCQLLoss"]
__all__ = [
"CQLImpl",
"DiscreteCQLImpl",
"CQLModules",
"DiscreteCQLLoss",
"CQLLoss",
]


@dataclasses.dataclass(frozen=True)
Expand All @@ -26,6 +34,12 @@ class CQLModules(SACModules):
alpha_optim: Optional[Optimizer]


@dataclasses.dataclass(frozen=True)
class CQLLoss(DDPGCriticLoss):
    """Critic loss of CQL together with the two terms it is built from.

    The inherited ``loss`` field is filled by ``CQLImpl.compute_critic_loss``
    with ``td_loss + conservative_loss``.
    """

    # Critic loss returned by the parent implementation's
    # ``compute_critic_loss`` (its ``.loss`` field).
    td_loss: torch.Tensor
    # Value returned by ``_compute_conservative_loss`` for the batch.
    conservative_loss: torch.Tensor


class CQLImpl(SACImpl):
_modules: CQLModules
_alpha_threshold: float
Expand Down Expand Up @@ -65,12 +79,28 @@ def __init__(

def compute_critic_loss(
self, batch: TorchMiniBatch, q_tpn: torch.Tensor
) -> torch.Tensor:
loss = super().compute_critic_loss(batch, q_tpn)
) -> CQLLoss:
loss = super().compute_critic_loss(batch, q_tpn).loss
conservative_loss = self._compute_conservative_loss(
batch.observations, batch.actions, batch.next_observations
)
return loss + conservative_loss
return CQLLoss(
loss=loss + conservative_loss,
td_loss=loss,
conservative_loss=conservative_loss,
)

def update_critic(self, batch: TorchMiniBatch) -> Dict[str, float]:
    """Perform one optimizer step on the critic for ``batch``.

    Gradients are cleared, the target is computed with ``compute_target``,
    the loss with ``compute_critic_loss``, and the ``loss`` field of the
    result is back-propagated before the optimizer steps.

    Returns:
        The loss object passed through ``asdict_as_float``.
    """
    critic_optim = self._modules.critic_optim
    critic_optim.zero_grad()

    target = self.compute_target(batch)
    critic_loss = self.compute_critic_loss(batch, target)

    # Only the combined ``loss`` field drives the gradient.
    critic_loss.loss.backward()
    critic_optim.step()

    return asdict_as_float(critic_loss)

def update_alpha(self, batch: TorchMiniBatch) -> Dict[str, float]:
assert self._modules.alpha_optim
Expand Down Expand Up @@ -274,5 +304,7 @@ def compute_loss(
)
loss = td_loss + self._alpha * conservative_loss
return DiscreteCQLLoss(
loss=loss, td_loss=td_loss, conservative_loss=conservative_loss
loss=loss,
td_loss=td_loss,
conservative_loss=self._alpha * conservative_loss,
)
23 changes: 18 additions & 5 deletions d3rlpy/algos/qlearning/torch/ddpg_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,20 @@
from torch import nn
from torch.optim import Optimizer

from ....dataclass_utils import asdict_as_float
from ....dataset import Shape
from ....models.torch import ContinuousEnsembleQFunctionForwarder, Policy
from ....torch_utility import Modules, TorchMiniBatch, hard_sync, soft_sync
from ..base import QLearningAlgoImplBase
from .utility import ContinuousQFunctionMixin

__all__ = ["DDPGImpl", "DDPGBaseImpl", "DDPGBaseModules", "DDPGModules"]
__all__ = [
"DDPGImpl",
"DDPGBaseImpl",
"DDPGBaseModules",
"DDPGModules",
"DDPGCriticLoss",
]


@dataclasses.dataclass(frozen=True)
Expand All @@ -24,6 +31,11 @@ class DDPGBaseModules(Modules):
critic_optim: Optimizer


@dataclasses.dataclass(frozen=True)
class DDPGCriticLoss:
    """Immutable container for the critic loss.

    ``update_critic`` back-propagates ``loss`` and returns the whole object
    through ``asdict_as_float``.
    """

    # Value returned by ``self._q_func_forwarder.compute_error(...)`` in
    # ``compute_critic_loss``.
    loss: torch.Tensor

class DDPGBaseImpl(
ContinuousQFunctionMixin, QLearningAlgoImplBase, metaclass=ABCMeta
):
Expand Down Expand Up @@ -63,22 +75,23 @@ def update_critic(self, batch: TorchMiniBatch) -> Dict[str, float]:

loss = self.compute_critic_loss(batch, q_tpn)

loss.backward()
loss.loss.backward()
self._modules.critic_optim.step()

return {"critic_loss": float(loss.cpu().detach().numpy())}
return asdict_as_float(loss)

def compute_critic_loss(
self, batch: TorchMiniBatch, q_tpn: torch.Tensor
) -> torch.Tensor:
return self._q_func_forwarder.compute_error(
) -> DDPGCriticLoss:
loss = self._q_func_forwarder.compute_error(
observations=batch.observations,
actions=batch.actions,
rewards=batch.rewards,
target=q_tpn,
terminals=batch.terminals,
gamma=self._gamma**batch.intervals,
)
return DDPGCriticLoss(loss=loss)

def update_actor(self, batch: TorchMiniBatch) -> Dict[str, float]:
# Q function should be inference mode for stability
Expand Down
113 changes: 113 additions & 0 deletions tests/algos/qlearning/test_callbacks.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,113 @@
import pytest
from typing import Any, Sequence, List, Union
from unittest.mock import MagicMock, Mock
from d3rlpy.dataset import Shape

from d3rlpy.algos.qlearning.torch.callbacks import ParameterReset
from d3rlpy.algos import QLearningAlgoBase, QLearningAlgoImplBase
from d3rlpy.torch_utility import Modules
import torch

from ...test_torch_utility import DummyModules


class LayerHasResetMock(torch.nn.Module):
    """Fake layer that exposes ``reset_parameters`` and counts its calls."""

    def __init__(self) -> None:
        super().__init__()
        self.reset_count = 0

    def reset_parameters(self):
        self.reset_count += 1
        return True


class LayerNoResetMock(torch.nn.Module):
    """Fake layer without a ``reset_parameters`` method."""


class EncoderMock(torch.nn.Module):
    """Fake encoder keeping its layers in ``_layers`` as ParameterReset expects."""

    def __init__(self, layers: Sequence[torch.nn.Module]) -> None:
        super().__init__()
        self._layers = torch.nn.ModuleList(layers)


class QFunctionMock(torch.nn.Module):
    """Fake Q-function exposing ``_encoder._layers`` and ``_fc`` submodules.

    ``layer_setup`` holds one flag per layer, encoder layers first and the
    output layer last; ``True`` creates a resettable layer.
    """

    def __init__(self, layer_setup: Sequence[bool]) -> None:
        super().__init__()
        *encoder_layers, output_layer = [
            LayerHasResetMock() if resettable else LayerNoResetMock()
            for resettable in layer_setup
        ]
        self._encoder = EncoderMock(encoder_layers)
        self._fc = output_layer

    def layers(self) -> List[torch.nn.Module]:
        """Return the encoder layers followed by the output layer."""
        return [*self._encoder._layers, self._fc]


# NOTE(review): these module-level objects are not referenced by the tests in
# this section; kept unchanged in case something else relies on them.
fc = torch.nn.Linear(100, 100)
optim = torch.optim.Adam(fc.parameters())
modules = DummyModules(fc=fc, optim=optim)


class ImplMock(MagicMock):
    """Mock implementation object that passes the ``isinstance`` check."""

    def __init__(self, q_funcs: List[torch.nn.Module]) -> None:
        super().__init__(spec=QLearningAlgoImplBase)
        self.q_function = q_funcs


class QLearningAlgoBaseMock(MagicMock):
    """Mock algorithm whose ``_impl`` holds two identical fake Q-functions."""

    def __init__(self, spec, layer_setup: Sequence[bool]) -> None:
        super().__init__(spec=spec)
        # Two Q-functions so the per-Q-function loop is exercised.
        q_funcs = [QFunctionMock(layer_setup) for _ in range(2)]
        self._impl = ImplMock(q_funcs=q_funcs)


def _reset_counts(algo) -> List[List[Union[int, None]]]:
    """Return per-layer reset counts (``None`` for non-resettable layers)."""
    return [
        [getattr(layer, "reset_count", None) for layer in q_func.layers()]
        for q_func in algo._impl.q_function
    ]


def test_check_layer_resets():
    # Encoder: [resettable, not resettable]; output layer: resettable.
    algo = QLearningAlgoBaseMock(
        spec=QLearningAlgoBase, layer_setup=[True, False, True]
    )
    replay_ratio = 2

    pr = ParameterReset(
        replay_ratio=replay_ratio,
        encoder_reset=[True, False],
        output_reset=True,
        algo=algo,
    )
    assert pr._check is True

    # Flags an encoder layer that has no ``reset_parameters``.
    with pytest.raises(ValueError):
        ParameterReset(
            replay_ratio=replay_ratio,
            encoder_reset=[True, True],
            output_reset=True,
            algo=algo,
        )

    # More encoder flags than encoder layers.
    with pytest.raises(ValueError):
        ParameterReset(
            replay_ratio=replay_ratio,
            encoder_reset=[True, False, False],
            output_reset=True,
            algo=algo,
        )

    # Fewer encoder flags than encoder layers.
    with pytest.raises(ValueError):
        ParameterReset(
            replay_ratio=replay_ratio,
            encoder_reset=[True],
            output_reset=True,
            algo=algo,
        )


def test_call():
    layer_setup = [True, False, True]
    untouched = [[0, None, 0], [0, None, 0]]
    reset_once = [[1, None, 1], [1, None, 1]]
    replay_ratio = 2

    # Flags validated eagerly in the constructor.
    algo = QLearningAlgoBaseMock(
        spec=QLearningAlgoBase, layer_setup=layer_setup
    )
    pr = ParameterReset(
        replay_ratio=replay_ratio,
        encoder_reset=[True, False],
        output_reset=True,
        algo=algo,
    )
    pr(algo=algo, epoch=1, total_step=100)
    assert _reset_counts(algo) == untouched
    pr(algo=algo, epoch=2, total_step=100)
    assert _reset_counts(algo) == reset_once

    # Flags validated lazily on the first call.
    algo = QLearningAlgoBaseMock(
        spec=QLearningAlgoBase, layer_setup=layer_setup
    )
    pr = ParameterReset(
        replay_ratio=replay_ratio,
        encoder_reset=[True, False],
        output_reset=True,
    )
    assert pr._check is False
    pr(algo=algo, epoch=1, total_step=100)
    assert pr._check is True
    assert _reset_counts(algo) == untouched
    pr(algo=algo, epoch=2, total_step=100)
    assert _reset_counts(algo) == reset_once