
Feat Sebulba recurrent IQL #1148

Open · wants to merge 7 commits into base: develop
Conversation

Louay-Ben-nessir
Contributor

What?

A recurrent IQL implementation using the Sebulba architecture.

Why?

An offline Sebulba base and non-JAX envs in Mava.

How?

Mixed the Sebulba structure from PPO with the learner code from Anakin IQL.
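For anyone new to the layout, here is a minimal sketch of the actor/learner split that Sebulba-style systems use (illustrative only: the queue and helper bodies are stand-ins, not Mava's actual pipeline):

import queue
import threading

import numpy as np

# Stand-in for Mava's pipeline: actor threads push rollout data, the learner pulls it.
rollout_queue: "queue.Queue" = queue.Queue(maxsize=4)

def actor(num_steps: int) -> None:
    for _ in range(num_steps):
        # Stand-in for eps-greedy action selection and stepping a (CPU) environment.
        rollout_queue.put(np.random.randn(8).astype(np.float32))

def learner(num_updates: int) -> None:
    for _ in range(num_updates):
        batch = rollout_queue.get()  # blocks until an actor has produced data
        _ = batch.mean()             # stand-in for the jitted IQL update

actor_thread = threading.Thread(target=actor, args=(100,))
learner_thread = threading.Thread(target=learner, args=(100,))
actor_thread.start(); learner_thread.start()
actor_thread.join(); learner_thread.join()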

@sash-a (Contributor) left a comment

I've looked through everything except the system file and it looks good, the Sebulba utils especially! Just some relatively minor style changes.

@@ -11,13 +11,13 @@ add_agent_id: True
min_buffer_size: 32
update_batch_size: 1 # Number of vectorised gradient updates per device.

-rollout_length: 2 # Number of environment steps per vectorised environment.
+rollout_length: 2 # Number of environment steps per vectorised enviro²nment.

Not sure how the ^2 got in there 😂

Suggested change
rollout_length: 2 # Number of environment steps per vectorised enviro²nment.
rollout_length: 2 # Number of environment steps per vectorised environment.

@@ -21,22 +21,22 @@
from jumanji.env import State
from typing_extensions import NamedTuple, TypeAlias

-from mava.types import Observation
+from mava.types import MavaObservation, Observation

Where do we still use Observation? Can we use MavaObservation everywhere?

Comment on lines +38 to +43
# PPO specific check
if "num_minibatches" in config.system:
    assert num_eval_samples % config.system.num_minibatches == 0, (
        f"Number of training samples per evaluator ({num_eval_samples}) "
        f"must be divisible by num_minibatches ({config.system.num_minibatches})."
    )

A thought on this: maybe we can split these up into multiple methods, e.g. check_num_updates, check_num_envs, etc., and then have a check_sebulba_config_ppo, check_anakin_config_ppo and a check_sebulba_config_iql that call the relevant methods?
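Something in this direction, perhaps (the helper names and config fields below are a sketch of the idea, not existing code):

from omegaconf import DictConfig

def check_num_envs(config: DictConfig) -> None:
    assert config.arch.num_envs % len(config.arch.learner_device_ids) == 0, (
        "The number of environments must be divisible by the number of learner devices."
    )

def check_minibatches(config: DictConfig, num_eval_samples: int) -> None:
    # PPO-specific check
    assert num_eval_samples % config.system.num_minibatches == 0, (
        f"Number of training samples per evaluator ({num_eval_samples}) "
        f"must be divisible by num_minibatches ({config.system.num_minibatches})."
    )

def check_sebulba_config_ppo(config: DictConfig, num_eval_samples: int) -> None:
    check_num_envs(config)
    check_minibatches(config, num_eval_samples)

def check_sebulba_config_iql(config: DictConfig) -> None:
    check_num_envs(config)  # IQL has no minibatch check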


# todo: remove the ppo dependencies when we make sebulba for other systems

This is a good point though, maybe there's something we can do about it 🤔

Maybe a protocol that has action, obs and reward? Not sure if there are any other common attributes.
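Perhaps something like the following; the attribute set (obs, action, reward) is only a guess at what the transition types have in common:

from typing import Protocol

from chex import Array

class TransitionLike(Protocol):
    """Structural type for the fields shared by the different transition NamedTuples."""

    obs: Array
    action: Array
    reward: Array

Since Protocol types are checked structurally, the existing PPO and IQL transitions would satisfy TransitionLike without having to inherit from it.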



# from https://github.com/EdanToledo/Stoix/blob/feat/sebulba-dqn/stoix/utils/rate_limiters.py
class RateLimiter:

This file is getting quite long, can we structure it like this:

utils/
  sebulba/
    utils.py
    rate_limiters.py
    pipelines.py


self.rate_limiter.sample()

if not self._queue.empty():

Can we rename this to metrics_queue? It wasn't clear what this was storing.


self.inserts = 0.0
self.samples = 0
self.deletes = 0

Wonder if we need deletes here?

Comment on lines +210 to +212
def __init__(
self, samples_per_insert: float, min_size_to_sample: int, min_diff: float, max_diff: float
):

Can we please add a good docstring here 🙏
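For example, something in this direction (a sketch; the exact semantics of min_diff/max_diff should be checked against the implementation):

def __init__(
    self, samples_per_insert: float, min_size_to_sample: int, min_diff: float, max_diff: float
):
    """Rate limiter that keeps the sample:insert ratio close to samples_per_insert.

    Args:
        samples_per_insert: Target average number of times each inserted item
            should be sampled from the buffer.
        min_size_to_sample: Minimum number of items that must have been inserted
            before sampling is allowed at all.
        min_diff: Sampling blocks while the error
            (inserts * samples_per_insert - samples) would drop below this value.
        max_diff: Inserting blocks while that error would rise above this value.
    """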

Raises:
ValueError: If error_buffer is smaller than max(1.0, samples_per_inserts).
"""
if isinstance(error_buffer, float) or isinstance(error_buffer, int):

nit

Suggested change
if isinstance(error_buffer, float) or isinstance(error_buffer, int):
if isinstance(error_buffer, (int, float)):

Comment on lines +275 to +277
terminated = np.repeat(
terminated[..., np.newaxis], repeats=self.num_agents, axis=-1
) # (B,) --> (B, N)

Does this already happen for smax and lbf?

@sash-a (Contributor) left a comment

Great work here! Really minor changes required. Happy to merge this pending some benchmarks

Comment on lines +122 to +123
action = eps_greedy_dist.sample(seed=key)
action = action[0, ...] # (1, B, A) -> (B, A)

Bit safer, as this will error if action's 0th dim is ever larger than 1.

Suggested change
action = eps_greedy_dist.sample(seed=key)
action = action[0, ...] # (1, B, A) -> (B, A)
action = eps_greedy_dist.sample(seed=key).squeeze(0) # (B, A)

next_timestep = env.step(cpu_action)

# Prepare the transition
terminal = (1 - timestep.discount[..., 0, jnp.newaxis]).astype(bool)

Are you sure we want to remove the agent dim here?

target: Array,
) -> Tuple[Array, Metrics]:
# axes switched here to scan over time
hidden_state, obs_term_or_trunc = prep_inputs_to_scannedrnn(obs, term_or_trunc)

A general comment: I think this would be a lot easier to read if we used done to mean term_or_trunc, which I think is a reasonable change. We'd have to make it in Anakin as well, though :/

Comment on lines +423 to +424
timing_dict = tree.map(lambda *x: np.mean(x), *rollout_times) | learn_times
timing_dict = tree.map(np.mean, timing_dict, is_leaf=lambda x: isinstance(x, list))

Two things: can we call this time_metrics, and can you add a shape-explainer comment? It's a bit hard to work out what is happening here.
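For example, the kind of comments that would help (assuming rollout_times is a list of per-actor timing dicts and learn_times maps metric names to lists of per-update timings; worth checking against the actual shapes):

# rollout_times: list of per-actor timing dicts -> averaged leaf-wise across actors
# learn_times: dict of per-update timing lists, merged in unchanged for now
time_metrics = tree.map(lambda *x: np.mean(x), *rollout_times) | learn_times
# collapse the remaining per-update lists (from learn_times) down to scalar means
time_metrics = tree.map(np.mean, time_metrics, is_leaf=lambda x: isinstance(x, list))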

"""

eps = jnp.maximum(
config.system.eps_min, 1 - (t / config.system.eps_decay) * (1 - config.system.eps_min)

Would be nice if we could set a different decay per actor, although I think that's out of scope for this PR. Maybe you could make an issue to add in some of the Ape-X DQN features; that would be great.
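For that future issue, the Ape-X style fixed per-actor schedule is roughly the following (a sketch using the constants from the Ape-X paper, eps = 0.4 and alpha = 7; nothing here is part of this PR):

import jax.numpy as jnp

def apex_epsilons(num_actors: int, base_eps: float = 0.4, alpha: float = 7.0) -> jnp.ndarray:
    """Per-actor exploration rates: eps_i = base_eps ** (1 + alpha * i / (N - 1))."""
    i = jnp.arange(num_actors)
    return base_eps ** (1 + alpha * i / jnp.maximum(num_actors - 1, 1))

Each actor would then keep its own constant epsilon instead of sharing the decayed schedule above.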

]:
"""Initialise learner_fn, network and learner state."""

# create temporory envoirnments.

nit

Suggested change
# create temporory envoirnments.
# create temporary environments.

@@ -31,3 +31,7 @@ gamma: 0.99 # discount factor

eps_min: 0.05
eps_decay: 1e5

# --- Sebulba parameters ---
data_sample_mean: 150 # Average number of times the learner should sample each item from the replay buffer.

Can we rather call this mean_data_sample_rate? It wasn't clear to me what it was when I read it in the system file.

Comment on lines +584 to +593
config.sample_per_insert = config.system.data_sample_mean * insert_to_sample_ratio
config.tolerance = config.sample_per_insert * config.system.error_tolerance

min_num_inserts = max(
config.system.sample_sequence_length // config.system.rollout_length,
config.system.min_buffer_size // config.system.rollout_length,
1,
)

rate_limiter = SampleToInsertRatio(config.sample_per_insert, min_num_inserts, config.tolerance)

We should probably put them in the system config so it's easier to find on things like neptune

Suggested change
config.sample_per_insert = config.system.data_sample_mean * insert_to_sample_ratio
config.tolerance = config.sample_per_insert * config.system.error_tolerance
min_num_inserts = max(
config.system.sample_sequence_length // config.system.rollout_length,
config.system.min_buffer_size // config.system.rollout_length,
1,
)
rate_limiter = SampleToInsertRatio(config.sample_per_insert, min_num_inserts, config.tolerance)
config.system.sample_per_insert = config.system.data_sample_mean * insert_to_sample_ratio
config.system.tolerance = config.system.sample_per_insert * config.system.error_tolerance
min_num_inserts = max(
config.system.sample_sequence_length // config.system.rollout_length,
config.system.min_buffer_size // config.system.rollout_length,
1,
)
rate_limiter = SampleToInsertRatio(config.system.sample_per_insert, min_num_inserts, config.system.tolerance)

Comment on lines +692 to +696
train_metrics["learner_step"] = (eval_step + 1) * config.system.num_updates_per_eval
train_metrics["learner_steps_per_second"] = (
config.system.num_updates_per_eval
) / time_metrics["learner_time_per_eval"]
logger.log(train_metrics, t, eval_step, LogEvent.TRAIN)

I think we should always log train metrics even if an episode hasn't finished yet, what do you think?

episode_return=episode_return,
)

if config.arch.absolute_metric and max_episode_return <= episode_return:

I never know what order the bools will be evaluated in so I always add brackets, because it might be doing (config.arch.absolute_metric and max_episode_return) <= episode_return

Suggested change
if config.arch.absolute_metric and max_episode_return <= episode_return:
if config.arch.absolute_metric and (max_episode_return <= episode_return):
