adding LSTM support to pretrain #315
base: master
Changes from 43 commits
@@ -87,9 +87,20 @@ def __init__(self, policy, env, gamma=0.99, n_steps=5, vf_coef=0.25, ent_coef=0.

     def _get_pretrain_placeholders(self):
         policy = self.train_model
+
+        if self.initial_state is None:
+            states_ph = None
+            snew_ph = None
+            dones_ph = None
+        else:
+            states_ph = policy.states_ph
+            snew_ph = policy.snew
+            dones_ph = policy.dones_ph
+
         if isinstance(self.action_space, gym.spaces.Discrete):
-            return policy.obs_ph, self.actions_ph, policy.policy
-        return policy.obs_ph, self.actions_ph, policy.deterministic_action
+            return policy.obs_ph, self.actions_ph, states_ph, snew_ph, dones_ph, policy.policy
+        return policy.obs_ph, self.actions_ph, states_ph, snew_ph, dones_ph, \
+            policy.deterministic_action

     def setup_model(self):
         with SetVerbosity(self.verbose):

Review comment (on `if self.initial_state is None:`): you should rather check the `recurrent` attribute instead.
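To make that suggestion concrete, here is a minimal sketch of the same method gated on the policy's `recurrent` attribute (which the review comments indicate already exists) rather than on `self.initial_state`. This is only an illustration of the proposed change, not code from the PR, and the same pattern would apply to the ACER variant below.

# Sketch of the reviewer's suggestion (not the PR's code): gate the recurrent
# placeholders on the policy's `recurrent` flag instead of `self.initial_state`.
def _get_pretrain_placeholders(self):
    policy = self.train_model

    if policy.recurrent:  # assumed boolean attribute, per the review comments
        states_ph = policy.states_ph
        snew_ph = policy.snew
        dones_ph = policy.dones_ph
    else:
        states_ph = None
        snew_ph = None
        dones_ph = None

    if isinstance(self.action_space, gym.spaces.Discrete):
        return policy.obs_ph, self.actions_ph, states_ph, snew_ph, dones_ph, policy.policy
    return policy.obs_ph, self.actions_ph, states_ph, snew_ph, dones_ph, \
        policy.deterministic_action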
@@ -152,8 +152,18 @@ def __init__(self, policy, env, gamma=0.99, n_steps=20, num_procs=1, q_coef=0.5,

     def _get_pretrain_placeholders(self):
         policy = self.step_model
         action_ph = policy.pdtype.sample_placeholder([None])
+
+        if self.initial_state is None:
+            states_ph = None
+            snew_ph = None
+            dones_ph = None
+        else:
+            states_ph = policy.states_ph
+            snew_ph = policy.snew
+            dones_ph = policy.dones_ph
+
         if isinstance(self.action_space, Discrete):
-            return policy.obs_ph, action_ph, policy.policy
+            return policy.obs_ph, action_ph, states_ph, snew_ph, dones_ph, policy.policy
         raise NotImplementedError('Only discrete actions are supported for ACER for now')

     def set_env(self, env):

Review comment (on `if self.initial_state is None:`): same remark as before.
@@ -50,6 +50,10 @@ def __init__(self, policy, env, verbose=0, *, requires_vec_env, policy_base, pol
         self.sess = None
         self.params = None
         self._param_load_ops = None
+        self.initial_state = None
+        self.n_batch = None
+        self.nminibatches = None
+        self.n_steps = None

         if env is not None:
             if isinstance(env, str):

Review comment (on `self.initial_state = None`): you don't need that variable, there is the `recurrent` attribute for that.
@@ -246,13 +250,24 @@ def pretrain(self, dataset, n_epochs=10, learning_rate=1e-4,
         else:
             val_interval = int(n_epochs / 10)

+        use_lstm = self.initial_state is not None
+
+        if use_lstm:
+            if self.nminibatches is None:
+                envs_per_batch = self.n_envs * self.n_steps
+            else:
+                batch_size = self.n_batch // self.nminibatches
+                envs_per_batch = batch_size // self.n_steps
+
         with self.graph.as_default():
             with tf.variable_scope('pretrain'):
                 if continuous_actions:
-                    obs_ph, actions_ph, deterministic_actions_ph = self._get_pretrain_placeholders()
+                    obs_ph, actions_ph, states_ph, snew_ph, dones_ph, \
+                        deterministic_actions_ph = self._get_pretrain_placeholders()
                     loss = tf.reduce_mean(tf.square(actions_ph - deterministic_actions_ph))
                 else:
-                    obs_ph, actions_ph, actions_logits_ph = self._get_pretrain_placeholders()
+                    obs_ph, actions_ph, states_ph, snew_ph, dones_ph, \
+                        actions_logits_ph = self._get_pretrain_placeholders()
                     # actions_ph has a shape of (n_batch,), we reshape it to (n_batch, 1)
                     # so that no additional change is needed in the dataloader
                     actions_ph = tf.expand_dims(actions_ph, axis=1)

Review comment (on `use_lstm = self.initial_state is not None`): same remark, you can use the `recurrent` attribute.
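A quick numeric check of the minibatch arithmetic above, with illustrative values (these numbers are not from the PR):

# Worked example of the envs_per_batch computation (illustrative values only).
n_envs, n_steps, nminibatches = 8, 128, 4

n_batch = n_envs * n_steps              # 1024 transitions per update
batch_size = n_batch // nminibatches    # 256 transitions per minibatch
envs_per_batch = batch_size // n_steps  # 2 environments' worth of sequences per minibatch
print(n_batch, batch_size, envs_per_batch)  # 1024 256 2

So each pretraining minibatch carries the recurrent state of two environments, which is why the initial LSTM state used below is sliced to `envs_per_batch`.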
@@ -272,13 +287,23 @@ def pretrain(self, dataset, n_epochs=10, learning_rate=1e-4,

         for epoch_idx in range(int(n_epochs)):
             train_loss = 0.0
+            if use_lstm:
+                state = self.initial_state[:envs_per_batch]
+
             # Full pass on the training set
             for _ in range(len(dataset.train_loader)):
-                expert_obs, expert_actions = dataset.get_next_batch('train')
+                expert_obs, expert_actions, expert_mask = dataset.get_next_batch('train')
                 feed_dict = {
                     obs_ph: expert_obs,
                     actions_ph: expert_actions,
                 }
-                train_loss_, _ = self.sess.run([loss, optim_op], feed_dict)
+
+                if use_lstm:
+                    feed_dict.update({states_ph: state, dones_ph: expert_mask})
+                    state, train_loss_, _ = self.sess.run([snew_ph, loss, optim_op], feed_dict)
+                else:
+                    train_loss_, _ = self.sess.run([loss, optim_op], feed_dict)
+
                 train_loss += train_loss_

Review comment (on `state = self.initial_state[:envs_per_batch]`): initial state is an attribute of the policy.
Author reply: yes and no.
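The two branches above differ only in the extra feed entries and in fetching the new LSTM state. A possible refactor, sketched here under the same variable names (not part of the PR), keeps a single `sess.run` call:

# Sketch of a single-sess.run training step (not the PR's code).
fetches = [loss, optim_op]
if use_lstm:
    feed_dict.update({states_ph: state, dones_ph: expert_mask})
    fetches = [snew_ph] + fetches

outputs = self.sess.run(fetches, feed_dict)
if use_lstm:
    state = outputs[0]     # carry the LSTM state over to the next minibatch
train_loss_ = outputs[-2]  # value of `loss`
train_loss += train_loss_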
@@ -288,9 +313,19 @@ def pretrain(self, dataset, n_epochs=10, learning_rate=1e-4,
                 val_loss = 0.0
                 # Full pass on the validation set
                 for _ in range(len(dataset.val_loader)):
-                    expert_obs, expert_actions = dataset.get_next_batch('val')
-                    val_loss_, = self.sess.run([loss], {obs_ph: expert_obs,
-                                                        actions_ph: expert_actions})
+                    expert_obs, expert_actions, expert_mask = dataset.get_next_batch('val')
+
+                    feed_dict = {
+                        obs_ph: expert_obs,
+                        actions_ph: expert_actions,
+                    }
+
+                    if use_lstm:
+                        feed_dict.update({states_ph: state, dones_ph: expert_mask})
+                        val_loss_, = self.sess.run([loss], feed_dict)
+                    else:
+                        val_loss_, = self.sess.run([loss], feed_dict)
+
                     val_loss += val_loss_

                 val_loss /= len(dataset.val_loader)

Review comment (on the `if use_lstm:` branch): you only need to update the feed_dict; the `sess.run` call is the same in both branches.
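Following that review comment, only the feed_dict needs to change in the recurrent case, since both branches run the identical loss op; a sketch of the deduplicated validation step:

# Sketch of the reviewer's simplification (not the PR's code).
if use_lstm:
    feed_dict.update({states_ph: state, dones_ph: expert_mask})
val_loss_, = self.sess.run([loss], feed_dict)
val_loss += val_loss_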
@@ -1,3 +1,3 @@
 from stable_baselines.gail.model import GAIL
-from stable_baselines.gail.dataset.dataset import ExpertDataset, DataLoader
+from stable_baselines.gail.dataset.dataset import ExpertDataset, ExpertDatasetLSTM, DataLoader
 from stable_baselines.gail.dataset.record_expert import generate_expert_traj
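For context, a rough sketch of how the feature added by this PR might be used end to end; the `ExpertDatasetLSTM` constructor arguments and the expert data file are assumptions for illustration, not taken from this diff:

# Rough usage sketch (assumed API details): behavior-cloning pretraining of an
# LSTM policy from expert data, then regular training.
from stable_baselines import A2C
from stable_baselines.gail import ExpertDatasetLSTM

dataset = ExpertDatasetLSTM(expert_path='expert_cartpole.npz',   # hypothetical file
                            traj_limitation=10, batch_size=128)  # assumed arguments
model = A2C('MlpLstmPolicy', 'CartPole-v1', verbose=1)
model.pretrain(dataset, n_epochs=10)   # threads the LSTM state through the BC updates
model.learn(total_timesteps=100000)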
Review comment: You can do:

    (suggested snippet not preserved in this export)

so it's more compact, same for the else case.

Author reply: I think having the variable declaration both vertical and horizontal interrupts the reading flow. Yes, it would make the code shorter, but also less readable in my opinion. But I will change it if you really want it that way.
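The reviewer's snippet did not survive the export; one plausible shape for the compact form being debated, purely as a guess, is a single conditional assignment:

# Hypothetical reconstruction of the "more compact" suggestion (a guess, not the
# reviewer's actual snippet): collapse the branch into one assignment.
states_ph, snew_ph, dones_ph = (None, None, None) if self.initial_state is None \
    else (policy.states_ph, policy.snew, policy.dones_ph)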