Rework garage.torch.optimizers #2177

Open · wants to merge 6 commits into base: master
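The diffs below rename OptimizerWrapper to MinibatchOptimizer at every call site. A minimal before/after sketch of the change (constructor arguments copied from the benchmark diff below; `policy` stands in for any constructed garage policy and is an assumption of this sketch):

import torch

# `policy` is assumed to be a constructed garage policy (see the diffs below).

# Before this PR:
from garage.torch.optimizers import OptimizerWrapper
policy_optimizer = OptimizerWrapper((torch.optim.Adam, dict(lr=2.5e-4)),
                                    policy,
                                    max_optimization_epochs=10,
                                    minibatch_size=64)

# After this PR (same constructor arguments, new name):
from garage.torch.optimizers import MinibatchOptimizer
policy_optimizer = MinibatchOptimizer((torch.optim.Adam, dict(lr=2.5e-4)),
                                      policy,
                                      max_optimization_epochs=10,
                                      minibatch_size=64)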
@@ -6,7 +6,7 @@
from garage.experiment import deterministic
from garage.sampler import RaySampler
from garage.torch.algos import PPO as PyTorch_PPO
-from garage.torch.optimizers import OptimizerWrapper
+from garage.torch.optimizers import MinibatchOptimizer
from garage.torch.policies import GaussianMLPPolicy as PyTorch_GMP
from garage.torch.value_functions import GaussianMLPValueFunction
from garage.trainer import Trainer
@@ -45,15 +45,15 @@ def ppo_garage_pytorch(ctxt, env_id, seed):
hidden_nonlinearity=torch.tanh,
output_nonlinearity=None)

-policy_optimizer = OptimizerWrapper((torch.optim.Adam, dict(lr=2.5e-4)),
-policy,
-max_optimization_epochs=10,
-minibatch_size=64)
+policy_optimizer = MinibatchOptimizer((torch.optim.Adam, dict(lr=2.5e-4)),
+policy,
+max_optimization_epochs=10,
+minibatch_size=64)

-vf_optimizer = OptimizerWrapper((torch.optim.Adam, dict(lr=2.5e-4)),
-value_function,
-max_optimization_epochs=10,
-minibatch_size=64)
+vf_optimizer = MinibatchOptimizer((torch.optim.Adam, dict(lr=2.5e-4)),
+value_function,
+max_optimization_epochs=10,
+minibatch_size=64)

sampler = RaySampler(agents=policy,
envs=env,
8 changes: 4 additions & 4 deletions src/garage/examples/torch/vpg_pendulum.py
@@ -11,7 +11,7 @@
from garage import wrap_experiment
from garage.envs import GymEnv
from garage.experiment.deterministic import set_seed
-from garage.sampler import RaySampler
+from garage.sampler import LocalSampler, RaySampler
from garage.torch.algos import VPG
from garage.torch.policies import GaussianMLPPolicy
from garage.torch.value_functions import GaussianMLPValueFunction
@@ -44,9 +44,9 @@ def vpg_pendulum(ctxt=None, seed=1):
hidden_nonlinearity=torch.tanh,
output_nonlinearity=None)

-sampler = RaySampler(agents=policy,
-envs=env,
-max_episode_length=env.spec.max_episode_length)
+sampler = LocalSampler(agents=policy,
+envs=env,
+max_episode_length=env.spec.max_episode_length)

algo = VPG(env_spec=env.spec,
policy=policy,
17 changes: 14 additions & 3 deletions src/garage/torch/__init__.py
@@ -1,32 +1,43 @@
"""PyTorch-backed modules and algorithms."""
# yapf: disable
-from garage.torch._functions import (as_torch_dict, compute_advantages,
-expand_var, filter_valids, flatten_batch,
+from garage.torch._dtypes import (ObservationBatch, ObservationOrder,
+ShuffledOptimizationNotSupported,
+observation_batch_to_packed_sequence)
+from garage.torch._functions import (as_tensor, as_torch_dict,
+compute_advantages, expand_var,
+filter_valids, flatten_batch,
flatten_to_single_vector, global_device,
NonLinearity, np_to_torch,
output_height_2d, output_width_2d,
pad_to_last, prefer_gpu,
product_of_gaussians, set_gpu_mode,
soft_update_model, state_dict_to,
-torch_to_np, update_module_params)
+torch_to_np, update_module_params,
+list_to_tensor)

# yapf: enable
__all__ = [
'NonLinearity',
'as_torch_dict',
'as_tensor',
'compute_advantages',
'expand_var',
'filter_valids',
'flatten_batch',
'flatten_to_single_vector',
'global_device',
'list_to_tensor',
'np_to_torch',
'ObservationBatch',
'observation_batch_to_packed_sequence',
'ObservationOrder',
'output_height_2d',
'output_width_2d',
'pad_to_last',
'prefer_gpu',
'product_of_gaussians',
'set_gpu_mode',
'ShuffledOptimizationNotSupported',
'soft_update_model',
'state_dict_to',
'torch_to_np',
Expand Down
170 changes: 170 additions & 0 deletions src/garage/torch/_dtypes.py
@@ -0,0 +1,170 @@
"""Data structures used in garage.torch."""
from dataclasses import dataclass
import enum

import torch
from torch import nn


class ShuffledOptimizationNotSupported(ValueError):
"""Raised by recurrent policies if they're passed a shuffled batch."""


class ObservationOrder(enum.IntEnum):
"""Defines the order of observations in an ObservationBatch.

See :class:`ObservationBatch` for detailed documentation.

"""
# Tensor contains a batch of "most recent" observations.
# This ordering is typically used when performing rollouts, and it is
# expected that stateful policies maintain their own state when using this
# ordering.
LAST = 0
# Tensor contains observations with timesteps from potentially different
# episodes in a shuffled order. Recurrent policies should raise
# ShuffledOptimizationNotSupported if they encounter this ordering.
SHUFFLED = 1
# Tensor contains all observations for a batch of episodes, in order.
EPISODES = 2


@dataclass(init=False, eq=False)
class ObservationBatch(torch.Tensor):
r"""The (differentiable) input to all pytorch policies.

Args:
observations (torch.Tensor): A torch tensor containing flattened
observations in a batch. Stateless policies should always operate
on this input. This input is passed to the super-constructor.
Shape depends on the order:
* If `order == LAST`, has shape :math:`(V, O)` (where V is the
vectorization level).
* If `order == SHUFFLED`, has shape :math:`(B, O)` (where B is the
mini-batch size).
* If order == EPISODES, has shape :math:`(N \bullet [T], O)`
(where N is the number of episodes, and T is the episode
lengths).
order (ObservationOrder): The order of observations in this batch. If
this is set to EPISODES, lengths must not be None.
lengths (torch.Tensor or None): Integer tensor containing the lengths
of each episode. Only has a value if `order == EPISODES`.
"""

order: ObservationOrder
lengths: torch.Tensor = None

def __new__(cls, observations, order, lengths=None):
"""Check that lengths is consistent with the rest of the fields.

Raises:
ValueError: If lengths is not consistent with another field.

Returns:
ObservationBatch: A new observation batch.

"""
self = super().__new__(cls, observations)
self.order = order
self.lengths = lengths
if self.order == ObservationOrder.EPISODES:
if self.lengths is None:
raise ValueError(
'lengths is None, but must be a torch.Tensor when '
'order == ObservationOrder.EPISODES')
assert self.lengths is not None
if self.lengths.dtype not in (torch.uint8, torch.int8, torch.int16,
torch.int32, torch.int64):
raise ValueError(
f'lengths has dtype {self.lengths.dtype}, but must have '
f'an integer dtype')
total_size = sum(self.lengths)
if self.shape[0] != total_size:
raise ValueError(
f'observations has batch size '
f'{self.shape[0]}, but must have batch '
f'size {total_size} to match lengths')
assert self.shape[0] == total_size
elif self.lengths is not None:
raise ValueError(
f'lengths has value {self.lengths}, but must be None '
f'when order == {self.order}')
return self

def __repr__(self):
return f'{type(self).__name__}({super().__repr__()}, order={self.order!r}, lengths={self.lengths!r})'

@classmethod
def __torch_function__(cls, func, types, args=(), kwargs=None):
# print(f"func: {func.__name__}, args: {args!r}, kwargs: {kwargs!r}")
if kwargs is None:
kwargs = {}
result = super().__torch_function__(func, types, args, kwargs)
# Fixup ObservationBatch instances returned from methods.
# In the future this might preserve order for some methods
if isinstance(result, ObservationBatch):
if not hasattr(result, 'order'):
result.order = ObservationOrder.SHUFFLED
if not hasattr(result, 'lengths'):
result.lengths = None
return result


def observation_batch_to_packed_sequence(observations):
"""Turn ObservationBatch into a torch.nn.utils.rnn.PackedSequence.

This function is not a method on ObservationBatch so that it can be called
on an observation Tensor that is not an ObservationBatch. This simplifies
the implementation of recurrent policies.

Args:
observations (torch.Tensor or ObservationBatch): Observations to
convert to PackedSequence.

Raises:
ShuffledOptimizationNotSupported: If called with an input that is not
an ObservationBatch or when `order != EPISODES`

Returns:
torch.nn.utils.rnn.PackedSequence: The sequence of flattened
observations.

"""
if not isinstance(observations, ObservationBatch):
raise ShuffledOptimizationNotSupported(
f'observations should be an ObservationBatch, but was of '
f'type {type(observations)!r} instead.')
if observations.order != ObservationOrder.EPISODES:
raise ShuffledOptimizationNotSupported(
f'order has value {observations.order} but must have order '
f'{ObservationOrder.EPISODES} to be converted to a PackedSequence')
sequence = []
start = 0
for length in observations.lengths:
stop = start + length
sequence.append(observations[start:stop])
start = stop
pack_sequence = nn.utils.rnn.pack_sequence
return pack_sequence(sequence, enforce_sorted=False)


def is_policy_recurrent(policy, env_spec):
"""Check if a torch policy is recurrent.

Args:
policy (garage.torch.Policy): Policy that might be recurrent.
env_spec (EnvSpec): Environment specification, used to build example
observations to pass to the policy.

Returns:
bool: If policy is recurrent.

"""
# as_tensor is defined in garage.torch._functions.
from garage.torch._functions import as_tensor
try:
policy.forward(
as_tensor([
env_spec.observation_space.sample(),
env_spec.observation_space.sample()
]))
except ShuffledOptimizationNotSupported:
return True
else:
return False
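A minimal usage sketch of the ObservationBatch API added above, assuming the new names are exported from garage.torch as in this PR's __init__.py; the shapes and values below are illustrative only:

import torch

from garage.torch import (ObservationBatch, ObservationOrder,
                          observation_batch_to_packed_sequence)

# Two episodes of lengths 3 and 2, each observation flattened to 4 floats.
flat_obs = torch.randn(5, 4)
batch = ObservationBatch(flat_obs, ObservationOrder.EPISODES,
                         lengths=torch.tensor([3, 2]))

# Stateless policies treat `batch` like any other (5, 4) tensor; recurrent
# policies can convert it into a PackedSequence instead.
packed = observation_batch_to_packed_sequence(batch)
print(packed.data.shape)  # torch.Size([5, 4])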
56 changes: 56 additions & 0 deletions src/garage/torch/_functions.py
@@ -106,6 +106,49 @@ def compute_advantages(discount, gae_lambda, max_episode_length, baselines,
return advantages


def discount_cumsum(x, discount):
"""Discounted cumulative sum of a 1D tensor, computed with conv1d.

Args:
x (torch.Tensor): Tensor of shape :math:`(T, )` to sum over.
discount (float): Discount factor.

Returns:
torch.Tensor: Tensor of shape :math:`(T, )` where
``result[t] = sum_k discount**k * x[t + k]``.

"""
# Filter of discount powers: [1, discount, discount**2, ...].
discount_x = torch.full((len(x), ),
discount,
dtype=torch.float,
device=x.device)
discount_x[0] = 1.0
filt = torch.cumprod(discount_x, dim=0)
# conv1d computes a cross-correlation, so left-pad with T - 1 zeros and
# keep the last T outputs: each one is the discounted sum over the future.
returns = F.conv1d(x.view(1, 1, -1),
filt.view(1, 1, -1),
padding=len(x) - 1)[0, 0, len(x) - 1:]
assert returns.shape == (len(x), )
return returns


def split_packed_tensor(t, lengths):
"""Split a tensor using a sequence of (start, stop) tuples."""
start = 0
for length in lengths:
stop = start + length
yield t[start:stop]
start = stop


def pad_packed_tensor(t, lengths, max_length=None):
"""Pad a packed tensor to shape ``(len(lengths), max_length) + t.shape[1:]``.

Args:
t (torch.Tensor): Tensor with sequences packed along its first axis.
lengths (list[int] or torch.Tensor): Length of each packed sequence.
max_length (int or None): Length to pad each sequence to. Defaults to
the longest value in lengths.

Returns:
torch.Tensor: Zero-padded tensor of padded sequences.

"""
if max_length is None:
max_length = max(lengths)
if max(lengths) > max_length:
raise ValueError(f'packed tensor contains a sequence of length '
f'{max(lengths)}, but was asked to pad to '
f'length {max_length}')
out = torch.zeros((
len(lengths),
max_length,
) + t.shape[1:],
dtype=t.dtype,
device=t.device)
for i, seq in enumerate(split_packed_tensor(t, lengths)):
out[i][:len(seq)] = seq
return out


def pad_to_last(nums, total_length, axis=-1, val=0):
"""Pad val to last in nums in given axis.

@@ -383,6 +426,19 @@ def state_dict_to(state_dict, device):
return state_dict


def as_tensor(data):
"""Convert a list to a PyTorch tensor

Args:
data (list): Data to convert to tensor

Returns:
torch.Tensor: A float tensor

"""
return torch.as_tensor(data, dtype=torch.float32, device=global_device())


# pylint: disable=W0223
class NonLinearity(nn.Module):
"""Wrapper class for non linear function or module.
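A quick sketch of the packed-tensor helpers added above. They live in the private garage.torch._functions module in this PR, so importing them directly (as below) is an assumption about internal usage rather than public API:

import torch

from garage.torch._functions import pad_packed_tensor, split_packed_tensor

# A packed tensor holding two episodes of lengths 3 and 2, with 4 features.
packed = torch.arange(20, dtype=torch.float).reshape(5, 4)

# split_packed_tensor yields one tensor per episode, in order.
episodes = list(split_packed_tensor(packed, [3, 2]))
print([tuple(e.shape) for e in episodes])  # [(3, 4), (2, 4)]

# pad_packed_tensor zero-pads to (n_episodes, max_length, *feature_shape).
padded = pad_packed_tensor(packed, [3, 2])
print(tuple(padded.shape))  # (2, 3, 4)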
8 changes: 4 additions & 4 deletions src/garage/torch/algos/maml_ppo.py
@@ -4,7 +4,7 @@
from garage import _Default
from garage.torch.algos import PPO
from garage.torch.algos.maml import MAML
-from garage.torch.optimizers import OptimizerWrapper
+from garage.torch.optimizers import MinibatchOptimizer


class MAMLPPO(MAML):
@@ -70,10 +70,10 @@ def __init__(self,
meta_evaluator=None,
evaluate_every_n_epochs=1):

-policy_optimizer = OptimizerWrapper(
+policy_optimizer = MinibatchOptimizer(
(torch.optim.Adam, dict(lr=inner_lr)), policy)
-vf_optimizer = OptimizerWrapper((torch.optim.Adam, dict(lr=inner_lr)),
-value_function)
+vf_optimizer = MinibatchOptimizer(
+(torch.optim.Adam, dict(lr=inner_lr)), value_function)

inner_algo = PPO(env.spec,
policy,
8 changes: 4 additions & 4 deletions src/garage/torch/algos/maml_trpo.py
@@ -5,7 +5,7 @@
from garage.torch.algos import VPG
from garage.torch.algos.maml import MAML
from garage.torch.optimizers import (ConjugateGradientOptimizer,
-OptimizerWrapper)
+MinibatchOptimizer)


class MAMLTRPO(MAML):
@@ -71,10 +71,10 @@ def __init__(self,
meta_evaluator=None,
evaluate_every_n_epochs=1):

-policy_optimizer = OptimizerWrapper(
+policy_optimizer = MinibatchOptimizer(
(torch.optim.Adam, dict(lr=inner_lr)), policy)
-vf_optimizer = OptimizerWrapper((torch.optim.Adam, dict(lr=inner_lr)),
-value_function)
+vf_optimizer = MinibatchOptimizer(
+(torch.optim.Adam, dict(lr=inner_lr)), value_function)

inner_algo = VPG(env.spec,
policy,