Implement garage.torch.GaussianLSTMPolicy
krzentner committed Jul 3, 2022
1 parent 9985301 commit 19e4dbb
Showing 3 changed files with 297 additions and 0 deletions.
2 changes: 2 additions & 0 deletions src/garage/torch/policies/__init__.py
@@ -7,6 +7,7 @@
from garage.torch.policies.discrete_cnn_policy import DiscreteCNNPolicy
from garage.torch.policies.discrete_qf_argmax_policy import (
    DiscreteQFArgmaxPolicy)
from garage.torch.policies.gaussian_lstm_policy import GaussianLSTMPolicy
from garage.torch.policies.gaussian_mlp_policy import GaussianMLPPolicy
from garage.torch.policies.policy import Policy
from garage.torch.policies.tanh_gaussian_mlp_policy import (
@@ -21,4 +22,5 @@
    'Policy',
    'TanhGaussianMLPPolicy',
    'ContextConditionedPolicy',
    'GaussianLSTMPolicy',
]
228 changes: 228 additions & 0 deletions src/garage/torch/policies/gaussian_lstm_policy.py
@@ -0,0 +1,228 @@
"""GaussianLSTMPolicy."""
import torch
from torch import nn
from torch.distributions import Normal

from garage.torch import (observation_batch_to_packed_sequence,
                          ObservationBatch, ObservationOrder,
                          ShuffledOptimizationNotSupported)
from garage.torch.modules import GaussianMLPModule
from garage.torch.policies.stochastic_policy import StochasticPolicy


class GaussianLSTMPolicy(StochasticPolicy):
"""LSTM whose outputs are fed into a Normal distribution..
A policy that contains a LSTM to make prediction based on a gaussian
distribution.
Args:
env_spec (EnvSpec): Environment specification.
name (str): Name of policy.
hidden_size (int): Hidden dimension for LSTM cell for mean.
num_layers (int): Number of recurrent layers.
hidden_mlp_nonlinearity (Callable): Activation function for
intermediate dense layer(s). It should return a torch.Tensor. Set
it to None to maintain a linear activation.
hidden_mlp_sizes (list[int]): Output dimension of dense layer(s) for
the MLP for mean. For example, (32, 32) means the MLP consists
of two hidden layers, each with 32 hidden units.
hidden_mlp_w_init (Callable): Initializer function for the weight
of intermediate dense layer(s). Should modify a torch.Tensor.
hidden_mlp_b_init (Callable): Initializer function for the bias
of intermediate dense layer(s). Should modify a torch.Tensor.
output_nonlinearity (callable): Activation function for output dense
layer. It should return a torch.Tensor. Set it to None to
maintain a linear activation.
output_w_init (Callable): Initializer function for the weight
of output dense layer(s). Should modify a torch.Tensor.
output_b_init (Callable): Initializer function for the bias
of output dense layer(s). Should modify a torch.Tensor.
recurrent_w_init (Callable): Initializer function for the weight
of recurrent layer(s). Should modify a torch.Tensor.
hidden_state_init (Callable): Initializer function for the
initial hidden state. Should modify a torch.Tensor.
hidden_state_init_trainable (bool): Bool for whether the initial
hidden state is trainable.
cell_state_init (Callable): Initializer function for the
initial cell state. Should modify a torch.Tensor.
cell_state_init_trainable (bool): Bool for whether the initial
cell state is trainable.
learn_std (bool): Is std trainable.
init_std (float): Initial value for std.
min_std (float): Minimum value for std.
max_std (float): Maximum value for std.
std_parameterization (str): How the std should be parametrized. There
are two options:
- exp: the logarithm of the std will be stored, and applied a
exponential transformation
- softplus: the std will be computed as log(1+exp(x))
layer_normalization (bool): Bool for using layer normalization or not.
std_parameterization (str): How the std should be parametrized. There
are two options:
- exp: the logarithm of the std will be stored, and applied a
exponential transformation.
- softplus: the std will be computed as log(1+exp(x)).
normal_distribution_cls (torch.distribution): normal distribution class
to be constructed and returned by a call to forward. By default, is
`torch.distributions.Normal`.
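
    Example:
        A minimal usage sketch (illustrative, not part of this commit);
        assumes ``env`` is a garage environment with flat Box observation
        and action spaces::

            policy = GaussianLSTMPolicy(env_spec=env.spec, hidden_size=32)
            policy.reset()
            obs = env.reset()[0]
            action, agent_info = policy.get_action(obs)
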
"""

    def __init__(self,
                 env_spec,
                 *,
                 name='GaussianLSTMPolicy',
                 hidden_size=32,
                 num_layers=2,
                 cell_state_init_trainable=True,
                 cell_state_init=None,
                 hidden_state_init=None,
                 hidden_state_init_trainable=True,
                 recurrent_w_init=None,
                 hidden_mlp_sizes=(32, ),
                 hidden_mlp_w_init=nn.init.xavier_uniform_,
                 hidden_mlp_b_init=nn.init.zeros_,
                 hidden_mlp_nonlinearity=torch.tanh,
                 output_nonlinearity=None,
                 output_w_init=nn.init.xavier_uniform_,
                 output_b_init=nn.init.zeros_,
                 learn_std=True,
                 init_std=1.0,
                 min_std=1e-6,
                 max_std=None,
                 std_parameterization='exp',
                 layer_normalization=False,
                 normal_distribution_cls=Normal):
        super().__init__(env_spec, name)

        if std_parameterization not in ('exp', 'softplus'):
            raise NotImplementedError(
                'std_parameterization must be one of "exp" or "softplus", '
                f'got {std_parameterization!r}')

        self._obs_dim = env_spec.observation_space.flat_dim
        self._action_dim = env_spec.action_space.flat_dim
        self._hidden_size = hidden_size
        self._num_layers = num_layers
        h0 = torch.zeros(num_layers, hidden_size)
        c0 = torch.zeros(num_layers, hidden_size)
        # Apply each initializer independently, so that passing only one of
        # them works as documented.
        if hidden_state_init is not None:
            hidden_state_init(h0)
        if cell_state_init is not None:
            cell_state_init(c0)
        if cell_state_init_trainable:
            self._c0 = nn.Parameter(c0)
        else:
            # register_buffer both stores the tensor and exposes it as
            # self._c0; assigning the attribute first would make the
            # registration fail.
            self.register_buffer('_c0', c0)
        if hidden_state_init_trainable:
            self._h0 = nn.Parameter(h0)
        else:
            self.register_buffer('_h0', h0)
        self._rnn = nn.LSTM(input_size=self._obs_dim,
                            hidden_size=hidden_size,
                            batch_first=False,
                            num_layers=num_layers)
        if recurrent_w_init is not None:
            recurrent_w_init(self._rnn)
        self._mlp = GaussianMLPModule(
            input_dim=hidden_size,
            output_dim=self._action_dim,
            hidden_sizes=hidden_mlp_sizes,
            hidden_nonlinearity=hidden_mlp_nonlinearity,
            hidden_w_init=hidden_mlp_w_init,
            hidden_b_init=hidden_mlp_b_init,
            output_nonlinearity=output_nonlinearity,
            output_w_init=output_w_init,
            output_b_init=output_b_init,
            learn_std=learn_std,
            init_std=init_std,
            min_std=min_std,
            max_std=max_std,
            std_parameterization=std_parameterization,
            layer_normalization=layer_normalization,
            normal_distribution_cls=normal_distribution_cls)
        self._state = None

    def _new_state(self, n_envs):
        """Compute a new state for running n_envs in parallel.

        Args:
            n_envs (int): Number of observations and actions in each call
                to get_actions().

        Returns:
            tuple[torch.Tensor, torch.Tensor]: Tensors containing h0 and c0
                with the batch dimension repeated n_envs times.

        """
        # expand() creates batch-sized views of h0/c0 without copying, so
        # gradients still flow to the (possibly trainable) initial states.
        h0 = self._h0.unsqueeze(1).expand(
            (self._num_layers, n_envs, self._hidden_size))
        c0 = self._c0.unsqueeze(1).expand(
            (self._num_layers, n_envs, self._hidden_size))
        return h0, c0

    def reset(self, do_resets=None):
        """Reset the policy.

        This only has an effect on recurrent policies. do_resets is an
        array of booleans indicating which internal states to reset. The
        length of do_resets should equal the number of parallel
        environments, i.e. the batch size of get_actions().

        Args:
            do_resets (numpy.ndarray): Bool array indicating which states
                to be reset.

        """
        if do_resets is None:
            do_resets = [True]
        h0, c0 = self._new_state(len(do_resets))
        if all(do_resets):
            # Copy, since the expanded views returned by _new_state() alias
            # the same memory across the batch dimension.
            self._state = (h0.contiguous(), c0.contiguous())
        else:
            for i, do_reset in enumerate(do_resets):
                if do_reset:
                    # Reset every layer's hidden and cell state for env i.
                    self._state[0][:, i] = h0[:, i]
                    self._state[1][:, i] = c0[:, i]

    def forward(self, observations):
        """Compute the action distributions from the observations.

        Args:
            observations (torch.Tensor): Batch of observations on the
                default torch device.

        Raises:
            ValueError: If observations is not consistent with reset().
            ShuffledOptimizationNotSupported: If passed a shuffled
                ObservationBatch or a tensor that is not an
                ObservationBatch.

        Returns:
            torch.distributions.Distribution: Batch distribution of actions.
            dict[str, torch.Tensor]: Additional agent_info, as torch
                Tensors.

        """
        if not isinstance(observations, ObservationBatch):
            raise ShuffledOptimizationNotSupported(
                f'observations are of type {type(observations)!r}, but '
                f'should be an ObservationBatch')
        if observations.order == ObservationOrder.LAST:
            if self._state is None:
                raise ValueError('get_action() called before reset()')
            if self._state[0].shape[1] != len(observations):
                raise ValueError(
                    f'observations has length {len(observations)} but '
                    f'should have length {self._state[0].shape[1]} to '
                    f'match the length of do_resets in reset()')
            # Add a sequence dimension of 1, since this is a single step.
            rnn_out, self._state = self._rnn(observations.unsqueeze(0),
                                             self._state)
        else:
            sequence = observation_batch_to_packed_sequence(observations)
            n_episodes = len(observations.lengths)
            start = self._new_state(n_episodes)
            rnn_out, _ = self._rnn(sequence, start)
        # Remove the sequence dimension before the MLP head.
        dist = self._mlp(rnn_out.squeeze(0))
        return (dist, dict(mean=dist.mean, log_std=(dist.variance**.5).log()))
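
A minimal rollout sketch for the step-wise (ObservationOrder.LAST) path; this is illustrative, not part of the commit, and assumes `env` is a `GymEnv`-wrapped garage environment whose `env.step()` returns an `EnvStep` with `observation` and `last` attributes as in the garage environment API:

policy = GaussianLSTMPolicy(env_spec=env.spec, hidden_size=32)
policy.reset()  # one parallel environment
obs = env.reset()[0]
for _ in range(100):
    action, _ = policy.get_action(obs)
    step = env.step(action)
    obs = step.observation
    if step.last:
        # Episode ended: clear the recurrent state before the next one.
        policy.reset()
        obs = env.reset()[0]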
67 changes: 67 additions & 0 deletions tests/garage/torch/policies/test_gaussian_lstm_policy.py
@@ -0,0 +1,67 @@
import pickle

import numpy as np
import pytest

from garage.envs import GymEnv
from garage.torch.policies import GaussianLSTMPolicy

# yapf: disable
from tests.fixtures.envs.dummy import DummyBoxEnv, DummyDictEnv

# yapf: enable


def test_get_action_dict_space():
    env = GymEnv(DummyDictEnv(obs_space_type='box', act_space_type='box'))
    policy = GaussianLSTMPolicy(env_spec=env.spec, hidden_size=4)
    policy.reset(do_resets=None)
    obs = env.reset()[0]

    action, _ = policy.get_action(obs)
    assert env.action_space.contains(action)

    policy.reset(do_resets=[True, True])

    actions, _ = policy.get_actions([obs, obs])
    for action in actions:
        assert env.action_space.contains(action)


# yapf: disable
@pytest.mark.parametrize('obs_dim, action_dim, hidden_size', [
    ((1, ), (1, ), 4),
    ((2, ), (2, ), 4),
    ((1, 1), (1, ), 4),
    ((2, 2), (2, ), 4)
])
# yapf: enable
def test_get_action(obs_dim, action_dim, hidden_size):
    env = GymEnv(DummyBoxEnv(obs_dim=obs_dim, action_dim=action_dim))
    policy = GaussianLSTMPolicy(env_spec=env.spec, hidden_size=hidden_size)

    policy.reset()
    obs = env.reset()[0]

    action, _ = policy.get_action(obs.flatten())
    assert env.action_space.contains(action)

    actions, _ = policy.get_actions([obs.flatten()])
    for action in actions:
        assert env.action_space.contains(action)


# pylint: disable=no-member
def test_is_pickleable():
    env = GymEnv(DummyBoxEnv(obs_dim=(1, ), action_dim=(1, )))
    policy = GaussianLSTMPolicy(env_spec=env.spec)
    policy.reset()
    obs = env.reset()[0]
    _, info = policy.get_action(obs)

    p = pickle.dumps(policy)
    policy_pickled = pickle.loads(p)
    policy_pickled.reset()
    _, info2 = policy_pickled.get_action(obs)
    assert np.array_equal(info['mean'], info2['mean'])
    assert np.array_equal(info['log_std'], info2['log_std'])
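
A possible follow-up test (a sketch, not part of this commit) exercising partial resets, where only some environments' recurrent states are cleared via the do_resets mask:

def test_partial_reset():
    env = GymEnv(DummyBoxEnv(obs_dim=(1, ), action_dim=(1, )))
    policy = GaussianLSTMPolicy(env_spec=env.spec, hidden_size=4)
    policy.reset(do_resets=[True, True])
    obs = env.reset()[0]
    # Step both environments once so their states diverge from h0/c0.
    policy.get_actions([obs.flatten(), obs.flatten()])
    # Clear only the first environment's recurrent state.
    policy.reset(do_resets=[True, False])
    actions, _ = policy.get_actions([obs.flatten(), obs.flatten()])
    for action in actions:
        assert env.action_space.contains(action)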
