adding LSTM support to pretrain #315

Open · wants to merge 55 commits into base: master

Changes from 43 commits (55 commits total)

Commits
a9f7e30
mode pretrain in base_calse
Apr 11, 2019
c708a37
add mask support to DataLoader
Apr 11, 2019
169a80b
add better data split
Apr 13, 2019
b0ee4c7
add comments and fix some bugs
Apr 13, 2019
027891f
when using dataset.get_next_batch to expect fore returns.
Apr 13, 2019
b7541bd
update _get_pretrain_placeholders in all models.
Apr 13, 2019
11a9d00
make it work.
May 8, 2019
004e3fb
Merge pull request #1 from hill-a/master
XMaster96 May 8, 2019
4d14812
Merge branch 'master' into LSTM-pretrain
May 8, 2019
167a337
Make it work with 2.5.1
May 8, 2019
187f16e
Merge branch 'master' into LSTM-pretrain
XMaster96 May 8, 2019
1827b2d
improve the syntax
May 8, 2019
34e7bde
Merge remote-tracking branch 'origin/LSTM-pretrain' into LSTM-pretrain
May 8, 2019
54e5c01
Merge branch 'master' into LSTM-pretrain
araffin May 14, 2019
c43b39a
-fix partial_minibatch for LSTMs
May 14, 2019
c7b795a
-fix data alignment for LSTMs
May 15, 2019
920ac7b
Merge remote-tracking branch 'origin/LSTM-pretrain' into LSTM-pretrain
May 15, 2019
ccddbb2
Delete __init__.py
XMaster96 May 15, 2019
ee29e78
Delete run_atari.py
XMaster96 May 15, 2019
938a4f8
Delete run_mujoco.py
XMaster96 May 15, 2019
a952d02
Delete ppo2.py
XMaster96 May 15, 2019
a2a94ad
-fix syntax line length.
May 15, 2019
58ddd30
Merge remote-tracking branch 'origin/LSTM-pretrain' into LSTM-pretrain
May 15, 2019
77538da
-fix syntax
May 15, 2019
a124dcb
-fix syntax
May 15, 2019
c4d9c47
remove nano.save
May 20, 2019
d75c01e
Merge branch 'master' into LSTM-pretrain
XMaster96 May 20, 2019
4d87f41
Merge branch 'master' into LSTM-pretrain
XMaster96 Jun 5, 2019
6ab6728
Merge branch 'master' into LSTM-pretrain
araffin Jun 27, 2019
b9e0fc0
Merge pull request #2 from hill-a/master
XMaster96 Jul 20, 2019
fa4bbcf
Merge branch 'master' into LSTM-pretrain
Jul 20, 2019
b30413e
split LSTM dataset from Expert dataset.
Jul 20, 2019
40a94ad
-fix syntax
Jul 20, 2019
cba1030
Merge branch 'master' into LSTM-pretrain
XMaster96 Aug 7, 2019
9ed3cfa
add TD3 support
Aug 7, 2019
96567be
Merge branch 'master' into LSTM-pretrain
araffin Aug 23, 2019
4bfb988
-fix indentation
Sep 2, 2019
30bdb19
-fix syntax
Sep 2, 2019
06ba9da
-fix syntax
Sep 3, 2019
a08f420
-fix syntax
Sep 3, 2019
17c04ac
Merge branch 'master' into LSTM-pretrain
XMaster96 Sep 3, 2019
1406af3
Merge branch 'master' into LSTM-pretrain
XMaster96 Sep 6, 2019
f448eea
Merge branch 'master' into LSTM-pretrain
araffin Sep 7, 2019
9770a12
Merge branch 'master' into LSTM-pretrain
araffin Sep 12, 2019
eda5b8f
-change
Sep 13, 2019
428022e
Merge branch 'LSTM-pretrain' of https://github.com/XMaster96/stable-b…
Sep 13, 2019
694e6c1
- change
Sep 13, 2019
f11f732
- change
Sep 13, 2019
d8685e4
- fix model save
Sep 14, 2019
e03be1d
-fix syntax
Sep 14, 2019
2f6da05
-fix syntax
Sep 14, 2019
570b8d9
-fix syntax
Sep 14, 2019
e0bb120
-fix pickle load
Sep 14, 2019
2ee1300
Merge branch 'master' into LSTM-pretrain
XMaster96 Sep 16, 2019
988ba5c
Merge branch 'master' into LSTM-pretrain
XMaster96 Sep 21, 2019
15 changes: 13 additions & 2 deletions stable_baselines/a2c/a2c.py
@@ -87,9 +87,20 @@ def __init__(self, policy, env, gamma=0.99, n_steps=5, vf_coef=0.25, ent_coef=0.

     def _get_pretrain_placeholders(self):
         policy = self.train_model
+
+        if self.initial_state is None:
Collaborator: You can do states_ph, snew_ph, dones_ph = None, None, None, so it's more compact; same for the else case.

Author: I think having the variable declarations both vertical and horizontal interrupts the reading flow. Yes, it would make the code shorter, but also less readable in my opinion. But I will change it if you really want it that way.

Collaborator: You should rather check the recurrent attribute of the policy; it is in the base policy class.

+            states_ph = None
+            snew_ph = None
+            dones_ph = None
+        else:
+            states_ph = policy.states_ph
+            snew_ph = policy.snew
+            dones_ph = policy.dones_ph
+
         if isinstance(self.action_space, gym.spaces.Discrete):
-            return policy.obs_ph, self.actions_ph, policy.policy
-        return policy.obs_ph, self.actions_ph, policy.deterministic_action
+            return policy.obs_ph, self.actions_ph, states_ph, snew_ph, dones_ph, policy.policy
+        return policy.obs_ph, self.actions_ph, states_ph, snew_ph, dones_ph, \
+            policy.deterministic_action

     def setup_model(self):
         with SetVerbosity(self.verbose):
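For reference, the variant suggested in the review comment above, branching on the policy's recurrent attribute instead of self.initial_state, could look roughly like the following. This is a minimal sketch under that assumption, not the code in this PR:

    def _get_pretrain_placeholders(self):
        policy = self.train_model

        # Sketch: branch on the base policy's `recurrent` flag, as suggested in
        # the review, instead of checking self.initial_state on the model.
        if policy.recurrent:
            states_ph, snew_ph, dones_ph = policy.states_ph, policy.snew, policy.dones_ph
        else:
            states_ph, snew_ph, dones_ph = None, None, None

        if isinstance(self.action_space, gym.spaces.Discrete):
            return policy.obs_ph, self.actions_ph, states_ph, snew_ph, dones_ph, policy.policy
        return policy.obs_ph, self.actions_ph, states_ph, snew_ph, dones_ph, \
            policy.deterministic_action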
12 changes: 11 additions & 1 deletion stable_baselines/acer/acer_simple.py
@@ -152,8 +152,18 @@ def __init__(self, policy, env, gamma=0.99, n_steps=20, num_procs=1, q_coef=0.5,
     def _get_pretrain_placeholders(self):
         policy = self.step_model
         action_ph = policy.pdtype.sample_placeholder([None])
+
+        if self.initial_state is None:
Collaborator: Same remark as before.

+            states_ph = None
+            snew_ph = None
+            dones_ph = None
+        else:
+            states_ph = policy.states_ph
+            snew_ph = policy.snew
+            dones_ph = policy.dones_ph
+
         if isinstance(self.action_space, Discrete):
-            return policy.obs_ph, action_ph, policy.policy
+            return policy.obs_ph, action_ph, states_ph, snew_ph, dones_ph, policy.policy
         raise NotImplementedError('Only discrete actions are supported for ACER for now')

     def set_env(self, env):
12 changes: 11 additions & 1 deletion stable_baselines/acktr/acktr_disc.py
@@ -104,8 +104,18 @@ def __init__(self, policy, env, gamma=0.99, nprocs=1, n_steps=20, ent_coef=0.01,

     def _get_pretrain_placeholders(self):
         policy = self.train_model
+
+        if self.initial_state is None:
+            states_ph = None
+            snew_ph = None
+            dones_ph = None
+        else:
+            states_ph = policy.states_ph
+            snew_ph = policy.snew
+            dones_ph = policy.dones_ph
+
         if isinstance(self.action_space, Discrete):
-            return policy.obs_ph, self.action_ph, policy.policy
+            return policy.obs_ph, self.action_ph, states_ph, snew_ph, dones_ph, policy.policy
         raise NotImplementedError("WIP: ACKTR does not support Continuous actions yet.")

     def setup_model(self):
47 changes: 41 additions & 6 deletions stable_baselines/common/base_class.py
@@ -50,6 +50,10 @@ def __init__(self, policy, env, verbose=0, *, requires_vec_env, policy_base, pol
         self.sess = None
         self.params = None
         self._param_load_ops = None
+        self.initial_state = None
Collaborator: You don't need that variable; there is the recurrent attribute for that.

+        self.n_batch = None
+        self.nminibatches = None
+        self.n_steps = None

         if env is not None:
             if isinstance(env, str):

@@ -246,13 +250,24 @@ def pretrain(self, dataset, n_epochs=10, learning_rate=1e-4,
             else:
                 val_interval = int(n_epochs / 10)

+        use_lstm = self.initial_state is not None
Collaborator: Same remark; you can use the recurrent attribute.

+        if use_lstm:
+            if self.nminibatches is None:
+                envs_per_batch = self.n_envs * self.n_steps
+            else:
+                batch_size = self.n_batch // self.nminibatches
+                envs_per_batch = batch_size // self.n_steps
+
         with self.graph.as_default():
             with tf.variable_scope('pretrain'):
                 if continuous_actions:
-                    obs_ph, actions_ph, deterministic_actions_ph = self._get_pretrain_placeholders()
+                    obs_ph, actions_ph, states_ph, snew_ph, dones_ph, \
+                        deterministic_actions_ph = self._get_pretrain_placeholders()
                     loss = tf.reduce_mean(tf.square(actions_ph - deterministic_actions_ph))
                 else:
-                    obs_ph, actions_ph, actions_logits_ph = self._get_pretrain_placeholders()
+                    obs_ph, actions_ph, states_ph, snew_ph, dones_ph, \
+                        actions_logits_ph = self._get_pretrain_placeholders()
                     # actions_ph has a shape if (n_batch,), we reshape it to (n_batch, 1)
                     # so no additional changes is needed in the dataloader
                     actions_ph = tf.expand_dims(actions_ph, axis=1)
@@ -272,13 +287,23 @@ def pretrain(self, dataset, n_epochs=10, learning_rate=1e-4,

         for epoch_idx in range(int(n_epochs)):
             train_loss = 0.0
+            if use_lstm:
+                state = self.initial_state[:envs_per_batch]
Collaborator: Initial state is an attribute of the policy.

Author: Yes and no. All the models which can use LSTM policies have the variable self.initial_state, which gets set to the initial state from the policy. It is that self.initial_state which gets used, not the one in the policy. It is also not that easy to access the policy's initial state from the BaseRLModel. It was much simpler to add the self.initial_state variable to the base model and then let it be overwritten later at model initialization.
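To illustrate what the author describes, here is a minimal, simplified sketch of that pattern (class and argument names are hypothetical stand-ins, not this PR's code): the base model defaults the attribute to None, and recurrent models overwrite it once their policy exists.

class BaseRLModel:
    def __init__(self):
        # Default for non-recurrent models; recurrent models overwrite this
        # once their policy (and its zero LSTM state) has been built.
        self.initial_state = None


class SomeRecurrentModel(BaseRLModel):  # hypothetical stand-in for A2C/ACER/PPO2
    def setup_model(self, step_model):
        # The policy owns the actual initial LSTM state; the model keeps a
        # reference so pretrain() can slice it per minibatch.
        self.initial_state = step_model.initial_state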


             # Full pass on the training set
             for _ in range(len(dataset.train_loader)):
-                expert_obs, expert_actions = dataset.get_next_batch('train')
+                expert_obs, expert_actions, expert_mask = dataset.get_next_batch('train')
                 feed_dict = {
                     obs_ph: expert_obs,
                     actions_ph: expert_actions,
                 }
+
+                if use_lstm:
+                    feed_dict.update({states_ph: state, dones_ph: expert_mask})
+                    state, train_loss_, _ = self.sess.run([snew_ph, loss, optim_op], feed_dict)
+                else:
+                    train_loss_, _ = self.sess.run([loss, optim_op], feed_dict)

-                train_loss_, _ = self.sess.run([loss, optim_op], feed_dict)
                 train_loss += train_loss_

@@ -288,9 +313,19 @@ def pretrain(self, dataset, n_epochs=10, learning_rate=1e-4,
                 val_loss = 0.0
                 # Full pass on the validation set
                 for _ in range(len(dataset.val_loader)):
-                    expert_obs, expert_actions = dataset.get_next_batch('val')
-                    val_loss_, = self.sess.run([loss], {obs_ph: expert_obs,
-                                                        actions_ph: expert_actions})
+                    expert_obs, expert_actions, expert_mask = dataset.get_next_batch('val')
+
+                    feed_dict = {
+                        obs_ph: expert_obs,
+                        actions_ph: expert_actions,
+                    }
+
+                    if use_lstm:
+                        feed_dict.update({states_ph: state, dones_ph: expert_mask})
+                        val_loss_, = self.sess.run([loss], feed_dict)
Collaborator: You only need to update the feed_dict; self.sess.run can be called outside the if/else, so you avoid code duplication.

+                    else:
+                        val_loss_, = self.sess.run([loss], feed_dict)

                     val_loss += val_loss_

                 val_loss /= len(dataset.val_loader)
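Following the review comment above, the duplicated self.sess.run calls could be avoided by building the fetch list conditionally and running the session once. Below is a minimal sketch of that refactor for the training loop, not the code in this PR; it assumes the placeholders and the use_lstm flag defined in the diff above. As a side note on the batch arithmetic earlier in this diff: with, say, n_envs = 8 and n_steps = 128 (so n_batch = 1024) and nminibatches = 4, batch_size is 1024 // 4 = 256 and envs_per_batch is 256 // 128 = 2.

for _ in range(len(dataset.train_loader)):
    expert_obs, expert_actions, expert_mask = dataset.get_next_batch('train')
    feed_dict = {obs_ph: expert_obs, actions_ph: expert_actions}
    fetches = [loss, optim_op]

    if use_lstm:
        # Also feed the current LSTM state and the episode-start mask,
        # and fetch the policy's new state for the next minibatch.
        feed_dict.update({states_ph: state, dones_ph: expert_mask})
        fetches = [snew_ph] + fetches

    # Single sess.run call covers both the recurrent and non-recurrent cases.
    outputs = self.sess.run(fetches, feed_dict)
    if use_lstm:
        state = outputs[0]
    train_loss_ = outputs[-2]
    train_loss += train_loss_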
2 changes: 1 addition & 1 deletion stable_baselines/ddpg/ddpg.py
@@ -308,7 +308,7 @@ def _get_pretrain_placeholders(self):
         policy = self.policy_tf
         # Rescale
         deterministic_action = self.actor_tf * np.abs(self.action_space.low)
-        return policy.obs_ph, self.actions, deterministic_action
+        return policy.obs_ph, self.actions, None, None, None, deterministic_action

     def setup_model(self):
         with SetVerbosity(self.verbose):
2 changes: 1 addition & 1 deletion stable_baselines/deepq/dqn.py
@@ -99,7 +99,7 @@ def __init__(self, policy, env, gamma=0.99, learning_rate=5e-4, buffer_size=5000

     def _get_pretrain_placeholders(self):
         policy = self.step_model
-        return policy.obs_ph, tf.placeholder(tf.int32, [None]), policy.q_values
+        return policy.obs_ph, tf.placeholder(tf.int32, [None]), None, None, None, policy.q_values

     def setup_model(self):
2 changes: 1 addition & 1 deletion stable_baselines/gail/__init__.py
@@ -1,3 +1,3 @@
 from stable_baselines.gail.model import GAIL
-from stable_baselines.gail.dataset.dataset import ExpertDataset, DataLoader
+from stable_baselines.gail.dataset.dataset import ExpertDataset, ExpertDatasetLSTM, DataLoader
 from stable_baselines.gail.dataset.record_expert import generate_expert_traj
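For context, end-to-end usage of pretrain with a recurrent policy and the new dataset class might look roughly like the sketch below. This is a hedged usage sketch: the exact constructor of ExpertDatasetLSTM is not shown in this excerpt, so it is assumed here to mirror ExpertDataset (expert_path, traj_limitation, batch_size), and the environment, file name, and hyperparameters are placeholder values.

from stable_baselines import PPO2
from stable_baselines.gail import ExpertDatasetLSTM, generate_expert_traj

# Record expert trajectories with some trained (or scripted) model first.
expert_model = PPO2('MlpLstmPolicy', 'CartPole-v1', nminibatches=1, verbose=1)
generate_expert_traj(expert_model, 'expert_cartpole', n_timesteps=10000, n_episodes=10)

# Assumed constructor, mirroring ExpertDataset; see the PR's dataset.py for the real signature.
dataset = ExpertDatasetLSTM(expert_path='expert_cartpole.npz', traj_limitation=-1, batch_size=128)

# Behavior-clone an LSTM policy, then continue with regular RL training.
model = PPO2('MlpLstmPolicy', 'CartPole-v1', nminibatches=1, verbose=1)
model.pretrain(dataset, n_epochs=100)
model.learn(total_timesteps=100000)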