Torch VPG rework

krzentner committed Jul 16, 2022
1 parent 19e4dbb commit df3a137
Showing 17 changed files with 289 additions and 277 deletions.

@@ -6,7 +6,7 @@
from garage.experiment import deterministic
from garage.sampler import RaySampler
from garage.torch.algos import PPO as PyTorch_PPO
from garage.torch.optimizers import OptimizerWrapper
from garage.torch.optimizers import MinibatchOptimizer
from garage.torch.policies import GaussianMLPPolicy as PyTorch_GMP
from garage.torch.value_functions import GaussianMLPValueFunction
from garage.trainer import Trainer
@@ -45,15 +45,15 @@ def ppo_garage_pytorch(ctxt, env_id, seed):
hidden_nonlinearity=torch.tanh,
output_nonlinearity=None)

policy_optimizer = OptimizerWrapper((torch.optim.Adam, dict(lr=2.5e-4)),
policy,
max_optimization_epochs=10,
minibatch_size=64)
policy_optimizer = MinibatchOptimizer((torch.optim.Adam, dict(lr=2.5e-4)),
policy,
max_optimization_epochs=10,
minibatch_size=64)

vf_optimizer = OptimizerWrapper((torch.optim.Adam, dict(lr=2.5e-4)),
value_function,
max_optimization_epochs=10,
minibatch_size=64)
vf_optimizer = MinibatchOptimizer((torch.optim.Adam, dict(lr=2.5e-4)),
value_function,
max_optimization_epochs=10,
minibatch_size=64)

sampler = RaySampler(agents=policy,
envs=env,
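The rename is mechanical: MinibatchOptimizer takes exactly the arguments OptimizerWrapper took at these call sites. A minimal construction sketch (a plain nn.Linear stands in for the garage policy or value function; that substitution is an assumption, not something this commit shows):

import torch
from torch import nn

from garage.torch.optimizers import MinibatchOptimizer

# Any torch.nn.Module whose parameters should be optimized.
module = nn.Linear(4, 2)
optimizer = MinibatchOptimizer((torch.optim.Adam, dict(lr=2.5e-4)),
                               module,
                               max_optimization_epochs=10,
                               minibatch_size=64)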
6 changes: 4 additions & 2 deletions src/garage/torch/__init__.py
@@ -3,8 +3,9 @@
from garage.torch._dtypes import (ObservationBatch, ObservationOrder,
ShuffledOptimizationNotSupported,
observation_batch_to_packed_sequence)
from garage.torch._functions import (as_torch_dict, compute_advantages,
expand_var, filter_valids, flatten_batch,
from garage.torch._functions import (as_tensor, as_torch_dict,
compute_advantages, expand_var,
filter_valids, flatten_batch,
flatten_to_single_vector, global_device,
NonLinearity, np_to_torch,
output_height_2d, output_width_2d,
@@ -18,6 +19,7 @@
__all__ = [
'NonLinearity',
'as_torch_dict',
'as_tensor',
'compute_advantages',
'expand_var',
'filter_valids',
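Because 'as_tensor' is added to __all__ above, it can be imported straight from garage.torch. A small usage sketch; the float32 dtype follows from the helper's dtype argument, and the tensor is placed on garage's global_device():

from garage.torch import as_tensor

t = as_tensor([1, 2, 3])
print(t.dtype)  # torch.float32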
4 changes: 2 additions & 2 deletions src/garage/torch/_dtypes.py
@@ -79,12 +79,12 @@ def __new__(cls, observations, order, lengths=None):
f'lengths has dtype {self.lengths.dtype}, but must have '
f'an integer dtype')
total_size = sum(self.lengths)
if self.observations.shape[0] != total_size:
if self.shape[0] != total_size:
raise ValueError(
f'observations has batch size '
f'{self.observations.shape[0]}, but must have batch '
f'size {total_size} to match lengths')
assert self.observations.shape[0] == total_size
assert self.shape[0] == total_size
elif self.lengths is not None:
raise ValueError(
f'lengths has value {self.lengths}, but must be None '
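Switching from self.observations.shape[0] to self.shape[0] assumes ObservationBatch is itself a torch.Tensor subclass (it is constructed in __new__), so both expressions name the same leading batch dimension that must equal sum(lengths). A tiny sketch of the invariant being checked, using plain tensors:

import torch

lengths = torch.tensor([3, 2])
observations = torch.zeros(int(lengths.sum()), 4)  # packed batch: 5 rows of 4 features
assert observations.shape[0] == int(lengths.sum())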
55 changes: 55 additions & 0 deletions src/garage/torch/_functions.py
@@ -106,6 +106,48 @@ def compute_advantages(discount, gae_lambda, max_episode_length, baselines,
return advantages


def discount_cumsum(x, discount):
"""Discounted cumulative sum of a 1D tensor, computed with a 1D convolution."""
discount_x = torch.full((len(x), ),
discount,
dtype=torch.float,
device=x.device)
discount_x[0] = 1.0
filter = torch.cumprod(discount_x, dim=0)
pad = len(x) - 1
# minibatch of 1, with 1 channel
filter = filter.reshape(1, 1, -1)
returns = F.conv1d(x.reshape(1, 1, -1), filter, stride=1, padding=pad)
returns = returns[0, 0, pad:]
return returns


def split_packed_tensor(t, lengths):
"""Split a tensor using a sequence of (start, stop) tuples."""
start = 0
for length in lengths:
stop = start + length
yield t[start:stop]
start = stop


def pad_packed_tensor(t, lengths, max_length=None):
"""Pad a packed tensor to shape (len(lengths), max_length, ...), zero-filled."""
if max_length is None:
max_length = max(lengths)
if max(lengths) > max_length:
raise ValueError(f'packed tensor contains a sequence of length '
f'{max(lengths)}, but was asked to pad to '
f'length {max_length}')
out = torch.zeros((
len(lengths),
max_length,
) + t.shape[1:],
dtype=t.dtype,
device=t.device)
for i, seq in enumerate(split_packed_tensor(t, lengths)):
out[i][:len(seq)] = seq
return out


def pad_to_last(nums, total_length, axis=-1, val=0):
"""Pad val to last in nums in given axis.
@@ -383,6 +425,19 @@ def state_dict_to(state_dict, device):
return state_dict


def as_tensor(data):
"""Convert a list to a PyTorch tensor
Args:
data (list): Data to convert to tensor
Returns:
torch.Tensor: A float tensor
"""
return torch.as_tensor(data, dtype=torch.float32, device=global_device())


# pylint: disable=W0223
class NonLinearity(nn.Module):
"""Wrapper class for non linear function or module.
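The new helpers operate on "packed" tensors: episodes concatenated along the first dimension, described by a list of per-episode lengths. A usage sketch (the private garage.torch._functions import path is where this diff defines them; the printed values were checked by hand against the implementations above):

import torch

from garage.torch._functions import discount_cumsum, pad_packed_tensor

rewards = torch.tensor([1., 1., 1., 2., 2.])  # two episodes packed together
lengths = [3, 2]

print(pad_packed_tensor(rewards, lengths))
# tensor([[1., 1., 1.],
#         [2., 2., 0.]])

print(discount_cumsum(torch.tensor([1., 1., 1.]), 0.5))
# tensor([1.7500, 1.5000, 1.0000])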
8 changes: 4 additions & 4 deletions src/garage/torch/algos/maml_ppo.py
@@ -4,7 +4,7 @@
from garage import _Default
from garage.torch.algos import PPO
from garage.torch.algos.maml import MAML
from garage.torch.optimizers import OptimizerWrapper
from garage.torch.optimizers import MinibatchOptimizer


class MAMLPPO(MAML):
@@ -70,10 +70,10 @@ def __init__(self,
meta_evaluator=None,
evaluate_every_n_epochs=1):

policy_optimizer = OptimizerWrapper(
policy_optimizer = MinibatchOptimizer(
(torch.optim.Adam, dict(lr=inner_lr)), policy)
vf_optimizer = OptimizerWrapper((torch.optim.Adam, dict(lr=inner_lr)),
value_function)
vf_optimizer = MinibatchOptimizer(
(torch.optim.Adam, dict(lr=inner_lr)), value_function)

inner_algo = PPO(env.spec,
policy,
8 changes: 4 additions & 4 deletions src/garage/torch/algos/maml_trpo.py
@@ -5,7 +5,7 @@
from garage.torch.algos import VPG
from garage.torch.algos.maml import MAML
from garage.torch.optimizers import (ConjugateGradientOptimizer,
OptimizerWrapper)
MinibatchOptimizer)


class MAMLTRPO(MAML):
@@ -71,10 +71,10 @@ def __init__(self,
meta_evaluator=None,
evaluate_every_n_epochs=1):

policy_optimizer = OptimizerWrapper(
policy_optimizer = MinibatchOptimizer(
(torch.optim.Adam, dict(lr=inner_lr)), policy)
vf_optimizer = OptimizerWrapper((torch.optim.Adam, dict(lr=inner_lr)),
value_function)
vf_optimizer = MinibatchOptimizer(
(torch.optim.Adam, dict(lr=inner_lr)), value_function)

inner_algo = VPG(env.spec,
policy,
8 changes: 4 additions & 4 deletions src/garage/torch/algos/maml_vpg.py
@@ -4,7 +4,7 @@
from garage import _Default
from garage.torch.algos import VPG
from garage.torch.algos.maml import MAML
from garage.torch.optimizers import OptimizerWrapper
from garage.torch.optimizers import MinibatchOptimizer


class MAMLVPG(MAML):
@@ -66,10 +66,10 @@ def __init__(self,
num_grad_updates=1,
meta_evaluator=None,
evaluate_every_n_epochs=1):
policy_optimizer = OptimizerWrapper(
policy_optimizer = MinibatchOptimizer(
(torch.optim.Adam, dict(lr=inner_lr)), policy)
vf_optimizer = OptimizerWrapper((torch.optim.Adam, dict(lr=inner_lr)),
value_function)
vf_optimizer = MinibatchOptimizer(
(torch.optim.Adam, dict(lr=inner_lr)), value_function)

inner_algo = VPG(env.spec,
policy,
10 changes: 5 additions & 5 deletions src/garage/torch/algos/ppo.py
@@ -2,7 +2,7 @@
import torch

from garage.torch.algos import VPG
from garage.torch.optimizers import OptimizerWrapper
from garage.torch.optimizers import MinibatchOptimizer


class PPO(VPG):
@@ -14,9 +14,9 @@ class PPO(VPG):
value_function (garage.torch.value_functions.ValueFunction): The value
function.
sampler (garage.sampler.Sampler): Sampler.
policy_optimizer (garage.torch.optimizer.OptimizerWrapper): Optimizer
policy_optimizer (garage.torch.optimizer.MinibatchOptimizer): Optimizer
for policy.
vf_optimizer (garage.torch.optimizer.OptimizerWrapper): Optimizer for
vf_optimizer (garage.torch.optimizer.MinibatchOptimizer): Optimizer for
value function.
lr_clip_range (float): The limit on the likelihood ratio between
policies.
@@ -63,13 +63,13 @@ def __init__(self,
entropy_method='no_entropy'):

if policy_optimizer is None:
policy_optimizer = OptimizerWrapper(
policy_optimizer = MinibatchOptimizer(
(torch.optim.Adam, dict(lr=2.5e-4)),
policy,
max_optimization_epochs=10,
minibatch_size=64)
if vf_optimizer is None:
vf_optimizer = OptimizerWrapper(
vf_optimizer = MinibatchOptimizer(
(torch.optim.Adam, dict(lr=2.5e-4)),
value_function,
max_optimization_epochs=10,
2 changes: 1 addition & 1 deletion src/garage/torch/algos/td3.py
@@ -84,9 +84,9 @@ def __init__(
replay_buffer,
sampler,
*, # Everything after this is numbers.
max_episode_length_eval=None,
grad_steps_per_env_step,
exploration_policy,
max_episode_length_eval=None,
uniform_random_policy=None,
max_action=None,
target_update_tau=0.005,
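Both orderings of these keyword-only parameters are legal Python; the reorder just groups the required keyword-only arguments (no default) ahead of the optional ones, so call sites are unaffected. A minimal illustration:

def f(*, required_a, required_b, optional=None):
    return required_a, required_b, optional

# Keyword arguments can be passed in any order regardless of the signature.
print(f(required_a=1, required_b=2))  # (1, 2, None)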
34 changes: 18 additions & 16 deletions src/garage/torch/algos/trpo.py
@@ -4,7 +4,7 @@
from garage.torch._functions import zero_optim_grads
from garage.torch.algos import VPG
from garage.torch.optimizers import (ConjugateGradientOptimizer,
OptimizerWrapper)
MinibatchOptimizer)


class TRPO(VPG):
@@ -16,9 +16,9 @@ class TRPO(VPG):
value_function (garage.torch.value_functions.ValueFunction): The value
function.
sampler (garage.sampler.Sampler): Sampler.
policy_optimizer (garage.torch.optimizer.OptimizerWrapper): Optimizer
policy_optimizer (garage.torch.optimizer.MinibatchOptimizer): Optimizer
for policy.
vf_optimizer (garage.torch.optimizer.OptimizerWrapper): Optimizer for
vf_optimizer (garage.torch.optimizer.MinibatchOptimizer): Optimizer for
value function.
num_train_per_epoch (int): Number of train_once calls per epoch.
discount (float): Discount.
@@ -62,11 +62,11 @@ def __init__(self,
entropy_method='no_entropy'):

if policy_optimizer is None:
policy_optimizer = OptimizerWrapper(
policy_optimizer = MinibatchOptimizer(
(ConjugateGradientOptimizer, dict(max_constraint_value=0.01)),
policy)
if vf_optimizer is None:
vf_optimizer = OptimizerWrapper(
vf_optimizer = MinibatchOptimizer(
(torch.optim.Adam, dict(lr=2.5e-4)),
value_function,
max_optimization_epochs=10,
@@ -117,7 +117,8 @@ def _compute_objective(self, advantages, obs, actions, rewards):

return surrogate

def _train_policy(self, obs, actions, rewards, advantages):
def _train_policy(self, observations, actions, rewards, advantages,
lengths):
r"""Train the policy.
Args:
@@ -129,18 +130,19 @@ def _train_policy(self, obs, actions, rewards, advantages):
with shape :math:`(N, )`.
advantages (torch.Tensor): Advantage value at each step
with shape :math:`(N, )`.
lengths (torch.Tensor): Lengths of episodes.
Returns:
torch.Tensor: Calculated mean scalar value of policy loss (float).
"""
# pylint: disable=protected-access
zero_optim_grads(self._policy_optimizer._optimizer)
loss = self._compute_loss_with_adv(obs, actions, rewards, advantages)
loss.backward()
self._policy_optimizer.step(
f_loss=lambda: self._compute_loss_with_adv(obs, actions, rewards,
advantages),
f_constraint=lambda: self._compute_kl_constraint(obs))

return loss
data = {
'observations': observations,
'actions': actions,
'rewards': rewards,
'advantages': advantages,
'lengths': lengths
}
f_constraint = lambda: self._compute_kl_constraint(observations)
return self._policy_optimizer.step(data, self._loss_function,
f_constraint)