Skip to content

Commit

Permalink
fix: use cached action_spec
Browse files Browse the repository at this point in the history
  • Loading branch information
WiemKhlifi committed Nov 20, 2024
1 parent e4a9008 commit 8880dc6
Show file tree
Hide file tree
Showing 12 changed files with 14 additions and 19 deletions.
2 changes: 1 addition & 1 deletion mava/advanced_usage/ff_ippo_store_experience.py
Original file line number Diff line number Diff line change
Expand Up @@ -361,7 +361,7 @@ def learner_setup(

# Define network and optimiser.
actor_torso = hydra.utils.instantiate(config.network.actor_network.pre_torso)
- action_head, _ = get_action_head(env.action_spec())
+ action_head, _ = get_action_head(env.action_spec)
actor_action_head = hydra.utils.instantiate(action_head, action_dim=env.action_dim)
critic_torso = hydra.utils.instantiate(config.network.critic_network.pre_torso)

Expand Down
2 changes: 1 addition & 1 deletion mava/configs/default/rec_iql.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ defaults:
- arch: anakin
- system: q_learning/rec_iql
- network: rnn # [rnn, rcnn]
- - env: smax # [cleaner, connector, vector-connector, gigastep, lbf, matrax, rware, smax, mpe]
+ - env: smax # [cleaner, connector, vector-connector, gigastep, lbf, matrax, rware, smax]

hydra:
searchpath:
Expand Down
2 changes: 1 addition & 1 deletion mava/systems/mat/anakin/mat.py
Original file line number Diff line number Diff line change
Expand Up @@ -355,7 +355,7 @@ def learner_setup(
init_x = env.observation_spec.generate_value()
init_x = tree.map(lambda x: x[None, ...], init_x)

- _, action_space_type = get_action_head(env.action_spec())
+ _, action_space_type = get_action_head(env.action_spec)

if action_space_type == "discrete":
init_action = jnp.zeros((1, config.system.num_agents), dtype=jnp.int32)
Expand Down
2 changes: 1 addition & 1 deletion mava/systems/ppo/anakin/ff_ippo.py
Original file line number Diff line number Diff line change
Expand Up @@ -362,7 +362,7 @@ def learner_setup(

# Define network and optimiser.
actor_torso = hydra.utils.instantiate(config.network.actor_network.pre_torso)
- action_head, _ = get_action_head(env.action_spec())
+ action_head, _ = get_action_head(env.action_spec)
actor_action_head = hydra.utils.instantiate(action_head, action_dim=env.action_dim)
critic_torso = hydra.utils.instantiate(config.network.critic_network.pre_torso)

Expand Down
2 changes: 1 addition & 1 deletion mava/systems/ppo/anakin/ff_mappo.py
Original file line number Diff line number Diff line change
Expand Up @@ -346,7 +346,7 @@ def learner_setup(

# Define network and optimiser.
actor_torso = hydra.utils.instantiate(config.network.actor_network.pre_torso)
- action_head, _ = get_action_head(env.action_spec())
+ action_head, _ = get_action_head(env.action_spec)
actor_action_head = hydra.utils.instantiate(action_head, action_dim=env.action_dim)
critic_torso = hydra.utils.instantiate(config.network.critic_network.pre_torso)

Expand Down
2 changes: 1 addition & 1 deletion mava/systems/ppo/anakin/rec_ippo.py
Original file line number Diff line number Diff line change
Expand Up @@ -457,7 +457,7 @@ def learner_setup(
# Define network and optimisers.
actor_pre_torso = hydra.utils.instantiate(config.network.actor_network.pre_torso)
actor_post_torso = hydra.utils.instantiate(config.network.actor_network.post_torso)
- action_head, _ = get_action_head(env.action_spec())
+ action_head, _ = get_action_head(env.action_spec)
actor_action_head = hydra.utils.instantiate(action_head, action_dim=env.action_dim)
critic_pre_torso = hydra.utils.instantiate(config.network.critic_network.pre_torso)
critic_post_torso = hydra.utils.instantiate(config.network.critic_network.post_torso)
Expand Down
2 changes: 1 addition & 1 deletion mava/systems/ppo/anakin/rec_mappo.py
Original file line number Diff line number Diff line change
Expand Up @@ -452,7 +452,7 @@ def learner_setup(
# Define network and optimiser.
actor_pre_torso = hydra.utils.instantiate(config.network.actor_network.pre_torso)
actor_post_torso = hydra.utils.instantiate(config.network.actor_network.post_torso)
- action_head, _ = get_action_head(env.action_spec())
+ action_head, _ = get_action_head(env.action_spec)
actor_action_head = hydra.utils.instantiate(action_head, action_dim=env.action_dim)
critic_pre_torso = hydra.utils.instantiate(config.network.critic_network.pre_torso)
critic_post_torso = hydra.utils.instantiate(config.network.critic_network.post_torso)
Expand Down
8 changes: 2 additions & 6 deletions mava/systems/sable/anakin/ff_sable.py
Original file line number Diff line number Diff line change
Expand Up @@ -372,17 +372,13 @@ def learner_setup(
# Get available TPU cores.
n_devices = len(jax.devices())

- # Get number of agents.
- config.system.num_agents = env.num_agents
-
# PRNG keys.
key, net_key = keys

# Get number of agents and actions.
action_dim = env.action_dim
- n_agents = env.action_spec().shape[0]
+ n_agents = env.num_agents
config.system.num_agents = n_agents
- config.system.num_actions = action_dim

# Setting the chunksize - many agent problems require chunking agents
# Create a dummy decay factor for FF Sable
Expand All @@ -397,7 +393,7 @@ def learner_setup(
# Set positional encoding to False, since ff-sable does not use temporal dependencies.
config.network.memory_config.timestep_positional_encoding = False

- _, action_space_type = get_action_head(env.action_spec())
+ _, action_space_type = get_action_head(env.action_spec)

# Define network.
sable_network = SableNetwork(
Expand Down
5 changes: 2 additions & 3 deletions mava/systems/sable/anakin/rec_sable.py
Original file line number Diff line number Diff line change
Expand Up @@ -417,9 +417,8 @@ def learner_setup(

# Get number of agents and actions.
action_dim = env.action_dim
- n_agents = env.action_spec().shape[0]
+ n_agents = env.num_agents
config.system.num_agents = n_agents
- config.system.num_actions = action_dim

# Setting the chunksize - smaller chunks save memory at the cost of speed
if config.network.memory_config.timestep_chunk_size:
Expand All @@ -429,7 +428,7 @@ def learner_setup(
else:
config.network.memory_config.chunk_size = config.system.rollout_length * n_agents

- _, action_space_type = get_action_head(env.action_spec())
+ _, action_space_type = get_action_head(env.action_spec)

# Define network.
sable_network = SableNetwork(
Expand Down
2 changes: 1 addition & 1 deletion mava/systems/sac/anakin/ff_hasac.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,7 +153,7 @@ def replicate(x: Any) -> Any:

# Making actor network
actor_torso = hydra.utils.instantiate(cfg.network.actor_network.pre_torso)
- action_head, _ = get_action_head(env.action_spec())
+ action_head, _ = get_action_head(env.action_spec)
actor_action_head = hydra.utils.instantiate(
action_head, action_dim=env.action_dim, independent_std=False
)
Expand Down
2 changes: 1 addition & 1 deletion mava/systems/sac/anakin/ff_isac.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,7 @@ def replicate(x: Any) -> Any:

# Making actor network
actor_torso = hydra.utils.instantiate(cfg.network.actor_network.pre_torso)
- action_head, _ = get_action_head(env.action_spec())
+ action_head, _ = get_action_head(env.action_spec)
actor_action_head = hydra.utils.instantiate(
action_head, action_dim=env.action_dim, independent_std=False
)
Expand Down
2 changes: 1 addition & 1 deletion mava/systems/sac/anakin/ff_masac.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,7 @@ def replicate(x: Any) -> Any:

# Making actor network
actor_torso = hydra.utils.instantiate(cfg.network.actor_network.pre_torso)
- action_head, _ = get_action_head(env.action_spec())
+ action_head, _ = get_action_head(env.action_spec)
actor_action_head = hydra.utils.instantiate(
action_head, action_dim=env.action_dim, independent_std=False
)
Expand Down

0 comments on commit 8880dc6

Please sign in to comment.