Feat Sebulba recurrent IQL #1148
base: develop
Conversation
I've looked through everything except the system file and it looks good, Sebulba utils especially! Just some relatively minor style changes
@@ -11,13 +11,13 @@ add_agent_id: True
min_buffer_size: 32
update_batch_size: 1 # Number of vectorised gradient updates per device.

-rollout_length: 2 # Number of environment steps per vectorised environment.
+rollout_length: 2 # Number of environment steps per vectorised enviro²nment.
Not sure how the ^2 got in there 😂
-rollout_length: 2 # Number of environment steps per vectorised enviro²nment.
+rollout_length: 2 # Number of environment steps per vectorised environment.
@@ -21,22 +21,22 @@
from jumanji.env import State
from typing_extensions import NamedTuple, TypeAlias

-from mava.types import Observation
+from mava.types import MavaObservation, Observation
Where do we still use Observation? Can we use MavaObservation everywhere?
# PPO specifique check
if "num_minibatches" in config.system:
    assert num_eval_samples % config.system.num_minibatches == 0, (
        f"Number of training samples per evaluator ({num_eval_samples})"
        + f"must be divisible by num_minibatches ({config.system.num_minibatches})."
    )
A thought on this: maybe we can split these up into multiple methods, e.g. check_num_updates, check_num_envs, etc. Then have a check_sebulba_config_ppo, check_anakin_config_ppo and a check_sebulba_config_iql which each call the relevant methods? Something like the sketch below.
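A minimal sketch of that split, assuming an OmegaConf config; the helper names follow the comment, and the config fields used inside each check (num_updates, num_evaluation, num_minibatches) are illustrative rather than Mava's actual schema:

# Illustrative only: field names are assumptions, not Mava's real config schema.
from omegaconf import DictConfig


def check_num_updates(config: DictConfig) -> None:
    assert config.system.num_updates % config.arch.num_evaluation == 0, (
        "num_updates must be divisible by num_evaluation."
    )


def check_num_minibatches(config: DictConfig, num_eval_samples: int) -> None:
    assert num_eval_samples % config.system.num_minibatches == 0, (
        f"Number of training samples per evaluator ({num_eval_samples}) "
        f"must be divisible by num_minibatches ({config.system.num_minibatches})."
    )


def check_sebulba_config_ppo(config: DictConfig, num_eval_samples: int) -> None:
    check_num_updates(config)
    check_num_minibatches(config, num_eval_samples)


def check_sebulba_config_iql(config: DictConfig) -> None:
    check_num_updates(config)  # IQL needs no minibatch divisibility check.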
# todo: remove the ppo dependencies when we make sebulba for other systems |
This is a good point though, maybe there's something we can do about it 🤔 Maybe a protocol that has action, obs and reward? Not sure if there are any other common attributes.
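For illustration, a minimal sketch of such a structural type; the class name TransitionLike is a placeholder and the attribute set just mirrors the comment:

from typing import Protocol

from chex import Array


class TransitionLike(Protocol):
    """Any transition exposing these fields could be accepted by the shared Sebulba utils."""

    action: Array
    obs: Array
    reward: Array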
# from https://github.com/EdanToledo/Stoix/blob/feat/sebulba-dqn/stoix/utils/rate_limiters.py
class RateLimiter:
This file is getting quite long, can we structure it like this:
-utils/
--sebulba/
---utils.py
---rate_limiters.py
---pipelines.py
self.rate_limiter.sample()

if not self._queue.empty():
Can we rename this to metrics_queue? It wasn't clear what this was storing.
self.inserts = 0.0
self.samples = 0
self.deletes = 0
Wonder if we need deletes here?
def __init__(
    self, samples_per_insert: float, min_size_to_sample: int, min_diff: float, max_diff: float
):
Can we please add a good docstring here 🙏
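Something along these lines could work, assuming the limiter tracks the difference samples_per_insert * inserts - samples the way the Reverb/Stoix limiters it is ported from do (wording is only a suggestion):

def __init__(
    self, samples_per_insert: float, min_size_to_sample: int, min_diff: float, max_diff: float
):
    """Blocks inserts/samples to keep samples close to samples_per_insert * inserts.

    Args:
        samples_per_insert: target average number of times each item is sampled
            for every insert into the buffer.
        min_size_to_sample: minimum number of inserts before sampling is allowed.
        min_diff: lower bound on samples_per_insert * inserts - samples; sampling
            blocks while the difference is below this value.
        max_diff: upper bound on the same difference; inserting blocks while the
            difference is above it.
    """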
Raises:
    ValueError: If error_buffer is smaller than max(1.0, samples_per_inserts).
"""
if isinstance(error_buffer, float) or isinstance(error_buffer, int):
nit
-if isinstance(error_buffer, float) or isinstance(error_buffer, int):
+if isinstance(error_buffer, (int, float)):
terminated = np.repeat(
    terminated[..., np.newaxis], repeats=self.num_agents, axis=-1
)  # (B,) --> (B, N)
Does this already happen for smax and lbf?
Great work here! Really minor changes required. Happy to merge this pending some benchmarks
action = eps_greedy_dist.sample(seed=key)
action = action[0, ...]  # (1, B, A) -> (B, A)
A bit safer, as this will error if the action's 0th dim is ever larger than 1.
-action = eps_greedy_dist.sample(seed=key)
-action = action[0, ...]  # (1, B, A) -> (B, A)
+action = eps_greedy_dist.sample(seed=key).squeeze(0)  # (B, A)
next_timestep = env.step(cpu_action)

# Prepare the transation
terminal = (1 - timestep.discount[..., 0, jnp.newaxis]).astype(bool)
Are you sure we want to remove the agent dim here?
    target: Array,
) -> Tuple[Array, Metrics]:
    # axes switched here to scan over time
    hidden_state, obs_term_or_trunc = prep_inputs_to_scannedrnn(obs, term_or_trunc)
A general comment: I think this would be a lot easier to read if we used done to mean term_or_trunc, which I think is a reasonable thing. We would have to make the change in Anakin as well though :/
timing_dict = tree.map(lambda *x: np.mean(x), *rollout_times) | learn_times
timing_dict = tree.map(np.mean, timing_dict, is_leaf=lambda x: isinstance(x, list))
Two things: can we call this time_metrics, and can you add a shape-explainer comment? It's a bit hard to work out what is happening here.
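A minimal, self-contained sketch of what that could look like; the keys, list lengths and the jax.tree import are illustrative assumptions:

import numpy as np
from jax import tree

# One timing dict per actor rollout in this eval period: List[Dict[str, float]]
rollout_times = [{"single_rollout_time": 0.51}, {"single_rollout_time": 0.49}]
# Learner timings collected over the eval period: Dict[str, List[float]]
learn_times = {"learner_time_per_eval": [0.20, 0.22, 0.21]}

# Mean over rollouts -> Dict[str, float], then merge in the (still list-valued) learner dict.
time_metrics = tree.map(lambda *x: np.mean(x), *rollout_times) | learn_times
# Reduce the remaining list-valued leaves (the learner timings) to scalars.
time_metrics = tree.map(np.mean, time_metrics, is_leaf=lambda x: isinstance(x, list))
print(time_metrics)  # {'single_rollout_time': 0.5, 'learner_time_per_eval': ~0.21}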
""" | ||
|
||
eps = jnp.maximum( | ||
config.system.eps_min, 1 - (t / config.system.eps_decay) * (1 - config.system.eps_min) |
Would be nice if we could set a different decay per actor, although I think that's out of scope for this PR. If you could open an issue to add some of the Ape-X DQN features, that would be great.
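For reference, a hedged sketch of the per-actor exploration schedule from Ape-X DQN (Horgan et al., 2018); the function name and defaults are illustrative, not part of this PR:

import numpy as np


def apex_epsilons(num_actors: int, base_eps: float = 0.4, alpha: float = 7.0) -> np.ndarray:
    """epsilon_i = base_eps ** (1 + i / (N - 1) * alpha) for actor i in 0..N-1."""
    i = np.arange(num_actors)
    return base_eps ** (1 + i / max(num_actors - 1, 1) * alpha)


print(apex_epsilons(4))  # actor 0 explores the most, actor N-1 is nearly greedy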
]:
    """Initialise learner_fn, network and learner state."""

    # create temporory envoirnments.
nit
-# create temporory envoirnments.
+# create temporary environments.
@@ -31,3 +31,7 @@ gamma: 0.99 # discount factor

eps_min: 0.05
eps_decay: 1e5

# --- Sebulba parameters ---
data_sample_mean: 150 # Average number of times the learner should sample each item from the replay buffer.
Can we rather call this mean_data_sample_rate? It wasn't clear to me what it was when I read it in the system file.
config.sample_per_insert = config.system.data_sample_mean * insert_to_sample_ratio
config.tolerance = config.sample_per_insert * config.system.error_tolerance

min_num_inserts = max(
    config.system.sample_sequence_length // config.system.rollout_length,
    config.system.min_buffer_size // config.system.rollout_length,
    1,
)

rate_limiter = SampleToInsertRatio(config.sample_per_insert, min_num_inserts, config.tolerance)
We should probably put them in the system config so it's easier to find on things like neptune
-config.sample_per_insert = config.system.data_sample_mean * insert_to_sample_ratio
-config.tolerance = config.sample_per_insert * config.system.error_tolerance
-min_num_inserts = max(
-    config.system.sample_sequence_length // config.system.rollout_length,
-    config.system.min_buffer_size // config.system.rollout_length,
-    1,
-)
-rate_limiter = SampleToInsertRatio(config.sample_per_insert, min_num_inserts, config.tolerance)
+config.system.sample_per_insert = config.system.data_sample_mean * insert_to_sample_ratio
+config.system.tolerance = config.system.sample_per_insert * config.system.error_tolerance
+min_num_inserts = max(
+    config.system.sample_sequence_length // config.system.rollout_length,
+    config.system.min_buffer_size // config.system.rollout_length,
+    1,
+)
+rate_limiter = SampleToInsertRatio(config.system.sample_per_insert, min_num_inserts, config.system.tolerance)
train_metrics["learner_step"] = (eval_step + 1) * config.system.num_updates_per_eval
train_metrics["learner_steps_per_second"] = (
    config.system.num_updates_per_eval
) / time_metrics["learner_time_per_eval"]
logger.log(train_metrics, t, eval_step, LogEvent.TRAIN)
I think we should always log train metrics even if an episode hasn't finished yet, what do you think?
    episode_return=episode_return,
)

if config.arch.absolute_metric and max_episode_return <= episode_return:
I never know what order the bools will be evaluated in so I always add brackets, because it might be doing (config.arch.absolute_metric and max_episode_return) <= episode_return
-if config.arch.absolute_metric and max_episode_return <= episode_return:
+if config.arch.absolute_metric and (max_episode_return <= episode_return):
What?
A recurrent IQL implementation using the Sebulba architecture.
Why?
An off-policy Sebulba base and non-JAX envs in Mava.
How?
Mixed the Sebulba structure from PPO with the learner code from Anakin IQL.
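To make the "How?" concrete, here is a minimal, self-contained sketch of the actor/learner split described above: actor threads push rollouts into a pipeline while a learner thread consumes them. All names (Pipeline, actor_thread, learner_thread) and the toy data are illustrative assumptions, not Mava's actual API; the real system adds a replay buffer, a rate limiter and the recurrent Q-network.

import queue
import threading

import numpy as np


class Pipeline:
    """Bounded queue moving rollouts from actor threads to the learner thread."""

    def __init__(self, max_size: int = 8):
        self._rollout_queue: queue.Queue = queue.Queue(maxsize=max_size)

    def put(self, rollout: np.ndarray) -> None:
        self._rollout_queue.put(rollout)

    def get(self) -> np.ndarray:
        return self._rollout_queue.get()


def actor_thread(pipeline: Pipeline, rollout_length: int, num_rollouts: int) -> None:
    # Stand-in for vectorised environment stepping; the real actor runs the
    # recurrent Q-network and picks epsilon-greedy actions.
    for _ in range(num_rollouts):
        rollout = np.random.rand(rollout_length, 4)  # (rollout_length, obs_dim)
        pipeline.put(rollout)


def learner_thread(pipeline: Pipeline, num_updates: int) -> None:
    # Stand-in for the Anakin IQL update; the real learner samples sequences
    # from a replay buffer gated by the rate limiter.
    for step in range(num_updates):
        batch = pipeline.get()
        loss = float(np.mean(batch**2))  # placeholder for the TD loss
        print(f"update {step}: loss={loss:.3f}")


if __name__ == "__main__":
    pipe = Pipeline()
    actor = threading.Thread(target=actor_thread, args=(pipe, 2, 5))
    learner = threading.Thread(target=learner_thread, args=(pipe, 5))
    actor.start()
    learner.start()
    actor.join()
    learner.join()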