
Commit

WIP torch optimizer refactor
krzentner committed Jan 19, 2021
1 parent 92646ed commit 017274f
Showing 17 changed files with 189 additions and 228 deletions.
@@ -6,7 +6,7 @@
from garage.experiment import deterministic
from garage.sampler import RaySampler
from garage.torch.algos import PPO as PyTorch_PPO
-from garage.torch.optimizers import OptimizerWrapper
+from garage.torch.optimizers import MinibatchOptimizer
from garage.torch.policies import GaussianMLPPolicy as PyTorch_GMP
from garage.torch.value_functions import GaussianMLPValueFunction
from garage.trainer import Trainer
@@ -45,15 +45,15 @@ def ppo_garage_pytorch(ctxt, env_id, seed):
hidden_nonlinearity=torch.tanh,
output_nonlinearity=None)

-policy_optimizer = OptimizerWrapper((torch.optim.Adam, dict(lr=2.5e-4)),
-policy,
-max_optimization_epochs=10,
-minibatch_size=64)
+policy_optimizer = MinibatchOptimizer((torch.optim.Adam, dict(lr=2.5e-4)),
+policy,
+max_optimization_epochs=10,
+minibatch_size=64)

-vf_optimizer = OptimizerWrapper((torch.optim.Adam, dict(lr=2.5e-4)),
-value_function,
-max_optimization_epochs=10,
-minibatch_size=64)
+vf_optimizer = MinibatchOptimizer((torch.optim.Adam, dict(lr=2.5e-4)),
+value_function,
+max_optimization_epochs=10,
+minibatch_size=64)

sampler = RaySampler(agents=policy,
envs=env,
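This commit renames OptimizerWrapper to MinibatchOptimizer throughout. The wrapper's own module is not part of this diff, so the following is only a minimal sketch of what such a wrapper could look like, inferred from the constructor arguments above and the step(data, loss_function, constraint) call in the trpo.py hunk further down. The class name suffix, the shuffling, and the epoch loop are assumptions, not garage's actual implementation, and the constraint argument is omitted.

```python
import torch


class MinibatchOptimizerSketch:
    """Hypothetical minibatch optimizer wrapper (illustration only)."""

    def __init__(self, optimizer, module, max_optimization_epochs=1,
                 minibatch_size=None):
        # `optimizer` may be a (class, kwargs) tuple, as in the calls above.
        if isinstance(optimizer, tuple):
            opt_cls, opt_kwargs = optimizer
            optimizer = opt_cls(module.parameters(), **opt_kwargs)
        self._optimizer = optimizer
        self._max_optimization_epochs = max_optimization_epochs
        self._minibatch_size = minibatch_size

    def step(self, data, loss_function):
        """Run several passes of minibatch updates over a dict of tensors."""
        n = len(next(iter(data.values())))
        batch_size = self._minibatch_size or n
        loss = None
        for _ in range(self._max_optimization_epochs):
            perm = torch.randperm(n)
            for start in range(0, n, batch_size):
                idx = perm[start:start + batch_size]
                minibatch = {k: v[idx] for k, v in data.items()}
                self._optimizer.zero_grad()
                loss = loss_function(**minibatch)
                loss.backward()
                self._optimizer.step()
        return loss


# Construction mirroring the benchmark script above, using a toy module.
policy = torch.nn.Linear(4, 2)
opt = MinibatchOptimizerSketch((torch.optim.Adam, dict(lr=2.5e-4)), policy,
                               max_optimization_epochs=10, minibatch_size=64)
```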
35 changes: 18 additions & 17 deletions src/garage/torch/__init__.py
@@ -1,23 +1,24 @@
"""PyTorch-backed modules and algorithms."""
# yapf: disable
-from garage.torch._dtypes import (ObservationBatch, ObservationOrder,
-ShuffledOptimizationNotSupported,
-observation_batch_to_packed_sequence)
-from garage.torch._functions import (compute_advantages, dict_np_to_torch,
-filter_valids, flatten_batch,
-flatten_to_single_vector, global_device,
-NonLinearity, np_to_torch, pad_to_last,
-prefer_gpu, product_of_gaussians,
-set_gpu_mode, soft_update_model,
-torch_to_np, TransposeImage,
-update_module_params)
+from garage.torch._dtypes import (observation_batch_to_packed_sequence,
+ObservationBatch, ObservationOrder,
+ShuffledOptimizationNotSupported)
+from garage.torch._functions import (as_tensor, compute_advantages,
+dict_np_to_torch, filter_valids,
+flatten_batch, flatten_to_single_vector,
+global_device, NonLinearity, np_to_torch,
+pad_to_last, prefer_gpu,
+product_of_gaussians, set_gpu_mode,
+soft_update_model, torch_to_np,
+TransposeImage, update_module_params)

# yapf: enable
__all__ = [
-'compute_advantages', 'dict_np_to_torch', 'filter_valids', 'flatten_batch',
-'global_device', 'np_to_torch', 'pad_to_last', 'prefer_gpu',
-'product_of_gaussians', 'set_gpu_mode', 'soft_update_model', 'torch_to_np',
-'update_module_params', 'NonLinearity', 'flatten_to_single_vector',
-'TransposeImage', 'ObservationBatch', 'ObservationOrder',
-'ShuffledOptimizationNotSupported', 'observation_batch_to_packed_sequence'
+'as_tensor', 'compute_advantages', 'dict_np_to_torch', 'filter_valids',
+'flatten_batch', 'global_device', 'np_to_torch', 'pad_to_last',
+'prefer_gpu', 'product_of_gaussians', 'set_gpu_mode', 'soft_update_model',
+'torch_to_np', 'update_module_params', 'NonLinearity',
+'flatten_to_single_vector', 'TransposeImage', 'ObservationBatch',
+'ObservationOrder', 'ShuffledOptimizationNotSupported',
+'observation_batch_to_packed_sequence'
]
4 changes: 2 additions & 2 deletions src/garage/torch/_dtypes.py
@@ -79,12 +79,12 @@ def __new__(cls, observations, order, lengths=None):
f'lengths has dtype {self.lengths.dtype}, but must have '
f'an integer dtype')
total_size = sum(self.lengths)
-if self.observations.shape[0] != total_size:
+if self.shape[0] != total_size:
raise ValueError(
f'observations has batch size '
f'{self.observations.shape[0]}, but must have batch '
f'size {total_size} to match lengths')
-assert self.observations.shape[0] == total_size
+assert self.shape[0] == total_size
elif self.lengths is not None:
raise ValueError(
f'lengths has value {self.lengths}, but must be None '
25 changes: 19 additions & 6 deletions src/garage/torch/_functions.py
@@ -92,13 +92,13 @@ def discount_cumsum(x, discount):
discount,
dtype=torch.float,
device=x.device)
discount_x[0] = 1.0
filter = torch.cumprod(discount_x, dim=0)
-returns = F.conv1d(x, filter, stride=1)
-assert returns.shape == (len(x), )
-from garage.np import discount_cumsum as np_discout_cumsum
-import numpy as np
-expected = np_discout_cumsum(torch_to_np(x), discount)
-assert np.array_equal(expected, torch_to_np(returns))
+pad = len(x) - 1
+# minibatch of 1, with 1 channel
+filter = filter.reshape(1, 1, -1)
+returns = F.conv1d(x.reshape(1, 1, -1), filter, stride=1, padding=pad)
+returns = returns[0, 0, pad:]
return returns
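The rewritten body computes the discounted cumulative sum with a single 1-D convolution: the kernel is [1, γ, γ², ...] and padding = len(x) - 1, so the output at position t sums only over x[t:]. A quick sanity check of the idea, as a standalone sketch assuming x is a 1-D float tensor (not garage code):

```python
import torch
import torch.nn.functional as F


def discount_cumsum_conv(x, discount):
    """Discounted cumulative sum via conv1d, mirroring the new implementation."""
    discount_x = torch.full((len(x),), discount, dtype=torch.float,
                            device=x.device)
    discount_x[0] = 1.0
    filt = torch.cumprod(discount_x, dim=0)   # [1, g, g^2, ...]
    pad = len(x) - 1
    filt = filt.reshape(1, 1, -1)             # minibatch of 1, with 1 channel
    out = F.conv1d(x.reshape(1, 1, -1), filt, stride=1, padding=pad)
    return out[0, 0, pad:]


def discount_cumsum_naive(x, discount):
    """Reference loop: returns[t] = sum_k discount**k * x[t + k]."""
    out = torch.zeros_like(x)
    running = 0.0
    for t in reversed(range(len(x))):
        running = x[t] + discount * running
        out[t] = running
    return out


x = torch.arange(5, dtype=torch.float)
assert torch.allclose(discount_cumsum_conv(x, 0.9), discount_cumsum_naive(x, 0.9))
```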


@@ -372,6 +372,19 @@ def product_of_gaussians(mus, sigmas_squared):
return mu, sigma_squared


+def as_tensor(data):
+"""Convert a list to a PyTorch tensor.
+
+Args:
+data (list): Data to convert to tensor
+
+Returns:
+torch.Tensor: A float tensor
+
+"""
+return torch.as_tensor(data, dtype=torch.float32, device=global_device())
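The new helper is a thin wrapper over torch.as_tensor with a fixed float32 dtype and the globally selected device. A tiny usage sketch, assuming the garage package at this commit is importable and the global device is the CPU default:

```python
import torch

from garage.torch import as_tensor

t = as_tensor([[1.0, 2.0], [3.0, 4.0]])
assert t.dtype == torch.float32   # always float32
assert t.shape == (2, 2)          # placed on global_device()
```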


# pylint: disable=W0223
class NonLinearity(nn.Module):
"""Wrapper class for non linear function or module.
8 changes: 4 additions & 4 deletions src/garage/torch/algos/maml_ppo.py
@@ -4,7 +4,7 @@
from garage import _Default
from garage.torch.algos import PPO
from garage.torch.algos.maml import MAML
-from garage.torch.optimizers import OptimizerWrapper
+from garage.torch.optimizers import MinibatchOptimizer


class MAMLPPO(MAML):
@@ -70,10 +70,10 @@ def __init__(self,
meta_evaluator=None,
evaluate_every_n_epochs=1):

-policy_optimizer = OptimizerWrapper(
+policy_optimizer = MinibatchOptimizer(
(torch.optim.Adam, dict(lr=inner_lr)), policy)
-vf_optimizer = OptimizerWrapper((torch.optim.Adam, dict(lr=inner_lr)),
-value_function)
+vf_optimizer = MinibatchOptimizer(
+(torch.optim.Adam, dict(lr=inner_lr)), value_function)

inner_algo = PPO(env.spec,
policy,
8 changes: 4 additions & 4 deletions src/garage/torch/algos/maml_trpo.py
@@ -5,7 +5,7 @@
from garage.torch.algos import VPG
from garage.torch.algos.maml import MAML
from garage.torch.optimizers import (ConjugateGradientOptimizer,
-OptimizerWrapper)
+MinibatchOptimizer)


class MAMLTRPO(MAML):
@@ -71,10 +71,10 @@ def __init__(self,
meta_evaluator=None,
evaluate_every_n_epochs=1):

-policy_optimizer = OptimizerWrapper(
+policy_optimizer = MinibatchOptimizer(
(torch.optim.Adam, dict(lr=inner_lr)), policy)
-vf_optimizer = OptimizerWrapper((torch.optim.Adam, dict(lr=inner_lr)),
-value_function)
+vf_optimizer = MinibatchOptimizer(
+(torch.optim.Adam, dict(lr=inner_lr)), value_function)

inner_algo = VPG(env.spec,
policy,
8 changes: 4 additions & 4 deletions src/garage/torch/algos/maml_vpg.py
@@ -4,7 +4,7 @@
from garage import _Default
from garage.torch.algos import VPG
from garage.torch.algos.maml import MAML
-from garage.torch.optimizers import OptimizerWrapper
+from garage.torch.optimizers import MinibatchOptimizer


class MAMLVPG(MAML):
@@ -66,10 +66,10 @@ def __init__(self,
num_grad_updates=1,
meta_evaluator=None,
evaluate_every_n_epochs=1):
-policy_optimizer = OptimizerWrapper(
+policy_optimizer = MinibatchOptimizer(
(torch.optim.Adam, dict(lr=inner_lr)), policy)
-vf_optimizer = OptimizerWrapper((torch.optim.Adam, dict(lr=inner_lr)),
-value_function)
+vf_optimizer = MinibatchOptimizer(
+(torch.optim.Adam, dict(lr=inner_lr)), value_function)

inner_algo = VPG(env.spec,
policy,
10 changes: 5 additions & 5 deletions src/garage/torch/algos/ppo.py
@@ -2,7 +2,7 @@
import torch

from garage.torch.algos import VPG
-from garage.torch.optimizers import OptimizerWrapper
+from garage.torch.optimizers import MinibatchOptimizer


class PPO(VPG):
@@ -14,9 +14,9 @@ class PPO(VPG):
value_function (garage.torch.value_functions.ValueFunction): The value
function.
sampler (garage.sampler.Sampler): Sampler.
-policy_optimizer (garage.torch.optimizer.OptimizerWrapper): Optimizer
+policy_optimizer (garage.torch.optimizer.MinibatchOptimizer): Optimizer
for policy.
-vf_optimizer (garage.torch.optimizer.OptimizerWrapper): Optimizer for
+vf_optimizer (garage.torch.optimizer.MinibatchOptimizer): Optimizer for
value function.
lr_clip_range (float): The limit on the likelihood ratio between
policies.
@@ -63,13 +63,13 @@ def __init__(self,
entropy_method='no_entropy'):

if policy_optimizer is None:
-policy_optimizer = OptimizerWrapper(
+policy_optimizer = MinibatchOptimizer(
(torch.optim.Adam, dict(lr=2.5e-4)),
policy,
max_optimization_epochs=10,
minibatch_size=64)
if vf_optimizer is None:
-vf_optimizer = OptimizerWrapper(
+vf_optimizer = MinibatchOptimizer(
(torch.optim.Adam, dict(lr=2.5e-4)),
value_function,
max_optimization_epochs=10,
33 changes: 18 additions & 15 deletions src/garage/torch/algos/trpo.py
@@ -3,7 +3,7 @@

from garage.torch.algos import VPG
from garage.torch.optimizers import (ConjugateGradientOptimizer,
-OptimizerWrapper)
+MinibatchOptimizer)


class TRPO(VPG):
@@ -15,9 +15,9 @@ class TRPO(VPG):
value_function (garage.torch.value_functions.ValueFunction): The value
function.
sampler (garage.sampler.Sampler): Sampler.
-policy_optimizer (garage.torch.optimizer.OptimizerWrapper): Optimizer
+policy_optimizer (garage.torch.optimizer.MinibatchOptimizer): Optimizer
for policy.
-vf_optimizer (garage.torch.optimizer.OptimizerWrapper): Optimizer for
+vf_optimizer (garage.torch.optimizer.MinibatchOptimizer): Optimizer for
value function.
num_train_per_epoch (int): Number of train_once calls per epoch.
discount (float): Discount.
@@ -61,11 +61,11 @@ def __init__(self,
entropy_method='no_entropy'):

if policy_optimizer is None:
-policy_optimizer = OptimizerWrapper(
+policy_optimizer = MinibatchOptimizer(
(ConjugateGradientOptimizer, dict(max_constraint_value=0.01)),
policy)
if vf_optimizer is None:
-vf_optimizer = OptimizerWrapper(
+vf_optimizer = MinibatchOptimizer(
(torch.optim.Adam, dict(lr=2.5e-4)),
value_function,
max_optimization_epochs=10,
@@ -116,7 +116,8 @@ def _compute_objective(self, advantages, obs, actions, rewards):

return surrogate

-def _train_policy(self, obs, actions, rewards, advantages):
+def _train_policy(self, observations, actions, rewards, advantages,
+lengths):
r"""Train the policy.
Args:
@@ -128,17 +128,19 @@ def _train_policy(self, obs, actions, rewards, advantages):
with shape :math:`(N, )`.
advantages (torch.Tensor): Advantage value at each step
with shape :math:`(N, )`.
+lengths (torch.Tensor): Lengths of episodes.
Returns:
torch.Tensor: Calculated mean scalar value of policy loss (float).
"""
-self._policy_optimizer.zero_grad()
-loss = self._compute_loss_with_adv(obs, actions, rewards, advantages)
-loss.backward()
-self._policy_optimizer.step(
-f_loss=lambda: self._compute_loss_with_adv(obs, actions, rewards,
-advantages),
-f_constraint=lambda: self._compute_kl_constraint(obs))
-
-return loss
+data = {
+'observations': observations,
+'actions': actions,
+'rewards': rewards,
+'advantages': advantages,
+'lengths': lengths
+}
+f_constraint = lambda: self._compute_kl_constraint(observations)
+return self._policy_optimizer.step(data, self._loss_function,
+f_constraint)
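After this change the policy update no longer calls zero_grad/backward/step itself: it packs the full batch into a dict and hands it, along with the loss function and a KL-constraint closure, to the optimizer's step. The constraint closure matters for ConjugateGradientOptimizer, which limits each step by the KL divergence from the pre-update policy. A rough, self-contained illustration of what such a closure computes, using a toy policy and torch.distributions rather than garage's _compute_kl_constraint:

```python
import torch
from torch.distributions import Normal, kl_divergence

obs = torch.randn(8, 4)
policy = torch.nn.Linear(4, 2)        # toy "policy": outputs action means

with torch.no_grad():
    # Snapshot of the action distribution before any parameter update.
    old_dist = Normal(policy(obs), torch.ones(8, 2))


def f_constraint():
    """Mean KL between the frozen pre-update policy and the current one."""
    new_dist = Normal(policy(obs), torch.ones(8, 2))
    return kl_divergence(old_dist, new_dist).mean()


print(f_constraint())   # ~0 before any optimizer step
```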
