Skip to content
This repository has been archived by the owner on Jan 27, 2023. It is now read-only.

Commit

Permalink
Modify clip_and_step to take Params
Browse files Browse the repository at this point in the history
  • Loading branch information
kngwyu committed Aug 15, 2019
1 parent a825c53 commit efa3d85
Show file tree
Hide file tree
Showing 6 changed files with 13 additions and 12 deletions.
2 changes: 1 addition & 1 deletion rainy/agents/ppo.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ def nstep(self, states: Array[State]) -> Array[State]:
(policy_loss
+ self.config.value_loss_weight * 0.5 * value_loss
- self.config.entropy_weight * entropy_loss).backward()
- mpi.clip_and_step(self.net, self.config.grad_clip, self.optimizer)
+ mpi.clip_and_step(self.net.parameters(), self.config.grad_clip, self.optimizer)
p, v, e = p + policy_loss.item(), v + value_loss.item(), e + entropy_loss.item()

self.lr_cooler.lr_decay(self.optimizer)
Expand Down
3 changes: 2 additions & 1 deletion rainy/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,11 @@
from typing import Callable, Dict, List, Optional, Tuple
from .envs import ClassicalControl, DummyParallelEnv, EnvExt, EnvGen, ParallelEnv
from .net import actor_critic, option_critic, value
- from .net.prelude import NetFn, Params
+ from .net.prelude import NetFn
from .lib.explore import DummyCooler, Cooler, LinearCooler, Explorer, EpsGreedy
from .lib import mpi
from .lib.kfac import KfacPreConditioner, PreConditioner
+ from .prelude import Params
from .replay import DqnReplayFeed, ReplayBuffer, UniformReplayBuffer
from .utils import Device, DummyLogger, Logger

Expand Down
2 changes: 1 addition & 1 deletion rainy/lib/kfac.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from torch.optim import Optimizer, SGD
from typing import Any, Callable, Generator, List, Optional, Tuple, Union
import warnings
- from ..net.prelude import Params
+ from ..prelude import Params


class Layer(Enum):
Expand Down
10 changes: 5 additions & 5 deletions rainy/lib/mpi.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from torch import Tensor
from torch.optim import Optimizer
from typing import Tuple
- from ..prelude import Array
+ from ..prelude import Array, Params
try:
import horovod.torch as hvd
hvd.init()
Expand All @@ -19,9 +19,9 @@ def setup_optimizer(opt: Optimizer) -> Optimizer:
hvd.broadcast_optimizer_state(opt, root_rank=0)
return hvd.DistributedOptimizer(opt)

- def clip_and_step(model: torch.nn.Module, max_norm: float, opt: Optimizer) -> None:
+ def clip_and_step(params: Params, max_norm: float, opt: Optimizer) -> None:
opt.synchronize()
- torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm)
+ torch.nn.utils.clip_grad_norm_(params, max_norm)
with opt.skip_synchronize():
opt.step()

Expand Down Expand Up @@ -63,8 +63,8 @@ def broadcast_model(model: torch.nn.Module) -> None:
def local_size_and_rank() -> Tuple[int, int]:
return 1, 0

- def clip_and_step(model: torch.nn.Module, max_norm: float, opt: Optimizer) -> None:
- torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm)
+ def clip_and_step(params: Params, max_norm: float, opt: Optimizer) -> None:
+ torch.nn.utils.clip_grad_norm_(params, max_norm)
opt.step()

def global_size() -> int:
Expand Down
5 changes: 2 additions & 3 deletions rainy/net/prelude.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
- from torch import nn, Tensor
- from typing import Callable, Iterable, Tuple, Union
+ from torch import nn
+ from typing import Callable, Tuple
from ..utils.device import Device

NetFn = Callable[[Tuple[int, ...], int, Device], nn.Module]
- Params = Union[Iterable[Tensor], dict]
3 changes: 2 additions & 1 deletion rainy/prelude.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from torch import Tensor
- from typing import Any, List, Sequence, Tuple, TypeVar, Union
+ from typing import Any, Iterable, List, Sequence, Tuple, TypeVar, Union


try:
Expand All @@ -14,6 +14,7 @@ class GenericNamedMeta(NamedTupleMeta, GenericMeta):
T = TypeVar('T')
Self = Any
Index = Union[None, int, slice, Tensor, List[Any], Tuple[Any, ...]]
+ Params = Union[Iterable[Tensor], dict]


class Array(Sequence[T]):
Expand Down

0 comments on commit efa3d85

Please sign in to comment.