From 50f323bb6ad5782e53383637628d618b68462fff Mon Sep 17 00:00:00 2001 From: StoneT2000 Date: Fri, 25 Oct 2024 08:40:01 -0700 Subject: [PATCH 01/11] work --- src/luxai_s3/env.py | 2 +- src/tests/test_gpu.py | 45 +++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 46 insertions(+), 1 deletion(-) create mode 100644 src/tests/test_gpu.py diff --git a/src/luxai_s3/env.py b/src/luxai_s3/env.py index c0a6a74..135b14d 100644 --- a/src/luxai_s3/env.py +++ b/src/luxai_s3/env.py @@ -426,7 +426,7 @@ def step( info[f"player_{k}"] = dict() return obs, state, reward, terminated_dict, truncated_dict, info - @functools.partial(jax.jit, static_argnums=(0, 2)) + @functools.partial(jax.jit, static_argnums=(0, )) def reset( self, key: chex.PRNGKey, params: Optional[EnvParams] = None ) -> Tuple[chex.Array, EnvState]: diff --git a/src/tests/test_gpu.py b/src/tests/test_gpu.py new file mode 100644 index 0000000..d0601d8 --- /dev/null +++ b/src/tests/test_gpu.py @@ -0,0 +1,45 @@ +import time +import jax +import flax.serialization +from luxai_s3.params import EnvParams +from luxai_s3.env import LuxAIS3Env +from luxai_s3.params import env_params_ranges +from luxai_s3.wrappers import LuxAIS3GymEnv, RecordEpisode +if __name__ == "__main__": + import numpy as np + np.random.seed(2) + + jax_env = LuxAIS3Env(auto_reset=True) + num_envs = 10 + seed = 0 + rng_key = jax.random.key(seed) + reset_fn = jax.vmap(jax_env.reset_env) + # sample random params initially + def sample_params(rng_key): + randomized_game_params = dict() + for k, v in env_params_ranges.items(): + rng_key, subkey = jax.random.split(rng_key) + randomized_game_params[k] = jax.random.choice(subkey, jax.numpy.array(v)) + params = EnvParams(**randomized_game_params) + return params + + rng_key, subkey = jax.random.split(rng_key) + env_params = jax.vmap(sample_params)(jax.random.split(subkey, num_envs)) + reset_fn(jax.random.split(subkey, num_envs), env_params) + + + + # env = LuxAIS3GymEnv() + # env = RecordEpisode(env, save_dir="episodes") + # obs, info = env.reset(seed=1) + + # print("Benchmarking time") + # stime = time.time() + # N = 100 + # # N = env.params.max_steps_in_match * env.params.match_count_per_episode + # for _ in range(N): + # env.step(env.action_space.sample()) + # etime = time.time() + # print(f"FPS: {N / (etime - stime)}") + + # env.close() \ No newline at end of file From 8d0a78f179871926de3c41319f9b031e171d25e2 Mon Sep 17 00:00:00 2001 From: StoneT2000 Date: Fri, 25 Oct 2024 08:40:37 -0700 Subject: [PATCH 02/11] work --- src/tests/test_gpu.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/tests/test_gpu.py b/src/tests/test_gpu.py index d0601d8..f8ab20a 100644 --- a/src/tests/test_gpu.py +++ b/src/tests/test_gpu.py @@ -4,6 +4,7 @@ from luxai_s3.params import EnvParams from luxai_s3.env import LuxAIS3Env from luxai_s3.params import env_params_ranges +from luxai_s3.state import gen_map from luxai_s3.wrappers import LuxAIS3GymEnv, RecordEpisode if __name__ == "__main__": import numpy as np @@ -25,7 +26,8 @@ def sample_params(rng_key): rng_key, subkey = jax.random.split(rng_key) env_params = jax.vmap(sample_params)(jax.random.split(subkey, num_envs)) - reset_fn(jax.random.split(subkey, num_envs), env_params) + # reset_fn(jax.random.split(subkey, num_envs), env_params) + jax.vmap(gen_map)(jax.random.split(subkey, num_envs), env_params) From 7710c760c50218a52538a4dcea05ccc0dc220c4d Mon Sep 17 00:00:00 2001 From: StoneT2000 Date: Fri, 25 Oct 2024 10:00:12 -0700 Subject: [PATCH 03/11] formatting and batching --- 
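A note on the batching pattern these patches build toward: patch 01 drops params from the static arguments of the jitted reset (static_argnums goes from (0, 2) to (0,)), so a batch of differently sampled EnvParams can be passed straight through jax.vmap, while shape-determining settings move into the fixed_env_params introduced below. The sketch that follows mirrors src/tests/test_gpu.py and uses only names that appear in this series (LuxAIS3Env, EnvParams, env_params_ranges, reset_env); it shows the intended usage rather than a tested snippet, and patch 02 temporarily switches the test to vmapping gen_map while reset itself is still being made batchable.

import jax
import jax.numpy as jnp
from luxai_s3.env import LuxAIS3Env
from luxai_s3.params import EnvParams, env_params_ranges

env = LuxAIS3Env(auto_reset=True)
num_envs = 10

def sample_params(rng_key):
    # one draw per randomized game setting; every draw has a fixed shape,
    # so the whole function vmaps into a stacked EnvParams pytree
    sampled = {}
    for k, v in env_params_ranges.items():
        rng_key, subkey = jax.random.split(rng_key)
        sampled[k] = jax.random.choice(subkey, jnp.array(v))
    return EnvParams(**sampled)

keys = jax.random.split(jax.random.key(0), num_envs)
batched_params = jax.vmap(sample_params)(keys)        # one EnvParams per environment
obs, state = jax.vmap(env.reset_env)(keys, batched_params)
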
src/luxai_runner/cli.py | 12 +- src/luxai_runner/episode.py | 24 +- src/luxai_runner/process.py | 6 +- src/luxai_s3/env.py | 698 +++++++++++++++++++++++++++------- src/luxai_s3/params.py | 20 +- src/luxai_s3/pygame_render.py | 23 +- src/luxai_s3/spaces.py | 3 +- src/luxai_s3/state.py | 110 ++---- src/luxai_s3/utils.py | 4 +- src/luxai_s3/wrappers.py | 92 +++-- src/tests/test.py | 16 +- src/tests/test_gpu.py | 50 ++- src/tests/test_gym.py | 8 +- 13 files changed, 766 insertions(+), 300 deletions(-) diff --git a/src/luxai_runner/cli.py b/src/luxai_runner/cli.py index 8b08a41..f37b455 100644 --- a/src/luxai_runner/cli.py +++ b/src/luxai_runner/cli.py @@ -15,6 +15,7 @@ from dataclasses import dataclass, field from typing import Optional + @dataclass class ReplayConfig: save_format: str = "json" @@ -22,6 +23,7 @@ class ReplayConfig: compressed_obs: bool = True """Whether to save compressed observations or not. Compressed observations do not contain the full observation at each step. In particular, the map information is stored as the first observation, subsequent observations only store the changes that happened.""" + @dataclass class Args: players: tyro.conf.Positional[List[str]] @@ -30,8 +32,8 @@ class Args: """Max episode length""" output: Optional[str] = None """Where to output replays. Default is none and no replay is generated""" - replay: ReplayConfig = field(default_factory=lambda : ReplayConfig()) - + replay: ReplayConfig = field(default_factory=lambda: ReplayConfig()) + verbose: int = 2 """Verbose Level (0 = silent, 1 = (game-ending errors, debug logs from agents), 2 = warnings (non-game ending invalid actions), 3 = info (system info, unit collisions) )""" seed: Optional[int] = None @@ -47,6 +49,7 @@ class Args: # skip_validate_action_space: bool = False # """Set this for a small performance increase. Note that turning this on means the engine assumes your submitted actions are valid. 
If your actions are not well formatted there could be errors""" + def main(): args = tyro.cli(Args) @@ -64,7 +67,9 @@ def main(): np.random.seed(args.seed) cfg = EpisodeConfig( players=args.players, - env_cls=lambda **kwargs: RecordEpisode(LuxAIS3GymEnv(numpy_output=True), save_on_close=False), + env_cls=lambda **kwargs: RecordEpisode( + LuxAIS3GymEnv(numpy_output=True), save_on_close=False + ), seed=args.seed, env_cfg=dict( # verbose=args.verbose, @@ -121,5 +126,6 @@ def main(): print("Time Elapsed: ", etime - stime) print("Rewards: ", results.rewards) + if __name__ == "__main__": main() diff --git a/src/luxai_runner/episode.py b/src/luxai_runner/episode.py index 703945d..0f4de9e 100644 --- a/src/luxai_runner/episode.py +++ b/src/luxai_runner/episode.py @@ -31,10 +31,12 @@ class EpisodeConfig: save_replay_path: Optional[str] = None replay_options: ReplayConfig = field(default_factory=ReplayConfig) + @dataclass class EpisodeResults: rewards: dict[str, float] + class Episode: def __init__(self, cfg: EpisodeConfig) -> None: self.cfg = cfg @@ -135,11 +137,11 @@ async def run(self): ) # if save_replay: - # replay = dict(observations=[], actions=[], dones=[], rewards=[]) - # if self.cfg.replay_options.compressed_obs: - # replay["observations"].append(state_obs) - # else: - # replay["observations"].append(self.env.state.get_obs()) + # replay = dict(observations=[], actions=[], dones=[], rewards=[]) + # if self.cfg.replay_options.compressed_obs: + # replay["observations"].append(state_obs) + # else: + # replay["observations"].append(self.env.state.get_obs()) i = 0 while not game_done: @@ -149,7 +151,10 @@ async def run(self): action_coros = [] for player in players.values(): action = player.step( - obs=obs[player.agent], step=i, reward=rewards[player.agent], info=infos[player.agent] + obs=obs[player.agent], + step=i, + reward=rewards[player.agent], + info=infos[player.agent], ) action_coros += [action] agent_ids += [player.agent] @@ -166,7 +171,9 @@ async def run(self): else: print(f"{agent_id} sent a invalid action {action}") actions[agent_id] = None - new_state_obs, rewards, terminations, truncations, infos = self.env.step(actions) + new_state_obs, rewards, terminations, truncations, infos = self.env.step( + actions + ) i += 1 # TODO (stao): hard code to avoid using jax structs in the infos and sending those. 
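# Rather than hard-coding empty infos, one option (untested here) is to convert the
# jax pytree to plain numpy before it is sent to the agent processes, the same way the
# gym wrapper already handles observations:
#     infos = to_numpy(flax.serialization.to_state_dict(infos))
# assuming luxai_s3.utils.to_numpy and flax are importable from the runner.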
infos = dict(player_0=dict(), player_1=dict()) @@ -195,5 +202,6 @@ async def run(self): await player.proc.cleanup() return EpisodeResults(rewards=rewards) + def close(self): - pass \ No newline at end of file + pass diff --git a/src/luxai_runner/process.py b/src/luxai_runner/process.py index be7d303..fabeb27 100644 --- a/src/luxai_runner/process.py +++ b/src/luxai_runner/process.py @@ -54,7 +54,11 @@ async def start(self): base_file_path = os.path.basename(self.file_path) if self.is_binary: self._agent_process = await asyncio.create_subprocess_exec( - f"{cwd}\{base_file_path}" if sys.platform.startswith('win') else f"./{base_file_path}", + ( + f"{cwd}\{base_file_path}" + if sys.platform.startswith("win") + else f"./{base_file_path}" + ), stdin=asyncio.subprocess.PIPE, stdout=asyncio.subprocess.PIPE, stderr=asyncio.subprocess.PIPE, diff --git a/src/luxai_s3/env.py b/src/luxai_s3/env.py index 135b14d..3a2072b 100644 --- a/src/luxai_s3/env.py +++ b/src/luxai_s3/env.py @@ -9,53 +9,99 @@ from gymnax.environments import environment, spaces from jax import lax -from luxai_s3.params import EnvParams +from luxai_s3.params import EnvParams, env_params_ranges from luxai_s3.spaces import MultiDiscrete -from luxai_s3.state import ASTEROID_TILE, ENERGY_NODE_FNS, NEBULA_TILE, EnvObs, EnvState, MapTile, UnitState, gen_state, spawn_unit +from luxai_s3.state import ( + ASTEROID_TILE, + ENERGY_NODE_FNS, + NEBULA_TILE, + EnvObs, + EnvState, + MapTile, + UnitState, + gen_state +) from luxai_s3.pygame_render import LuxAIPygameRenderer class LuxAIS3Env(environment.Environment): - def __init__(self, auto_reset=False, **kwargs): + def __init__( + self, auto_reset=False, fixed_env_params: EnvParams = EnvParams(), **kwargs + ): super().__init__(**kwargs) self.renderer = LuxAIPygameRenderer() self.auto_reset = auto_reset + self.fixed_env_params = fixed_env_params + """fixed env params for concrete/static values. 
Necessary for jit/vmap capability with randomly sampled maps which must of consistent shape""" @property def default_params(self) -> EnvParams: return EnvParams() - + def compute_unit_counts_map(self, state: EnvState, params: EnvParams): # map of total units per team on each tile, shape (num_teams, map_width, map_height) - unit_counts_map = jnp.zeros((params.num_teams, params.map_width, params.map_height), dtype=jnp.int32) + unit_counts_map = jnp.zeros( + (params.num_teams, params.map_width, params.map_height), dtype=jnp.int32 + ) + def update_unit_counts_map(unit_position, unit_mask, unit_counts_map): - unit_counts_map = unit_counts_map.at[unit_position[0], unit_position[1]].add(unit_mask) + unit_counts_map = unit_counts_map.at[ + unit_position[0], unit_position[1] + ].add(unit_mask) return unit_counts_map + for t in range(params.num_teams): - unit_counts_map = unit_counts_map.at[t].add(jnp.sum(jax.vmap(update_unit_counts_map, in_axes=(0, 0, None), out_axes=0)(state.units.position[t], state.units_mask[t], unit_counts_map[t]), axis=0)) + unit_counts_map = unit_counts_map.at[t].add( + jnp.sum( + jax.vmap(update_unit_counts_map, in_axes=(0, 0, None), out_axes=0)( + state.units.position[t], state.units_mask[t], unit_counts_map[t] + ), + axis=0, + ) + ) return unit_counts_map - + def compute_energy_features(self, state: EnvState, params: EnvParams): # first compute a array of shape (map_height, map_width, num_energy_nodes) with values equal to the distance of the tile to the energy node - mm = jnp.meshgrid(jnp.arange(params.map_width), jnp.arange(params.map_height)) - mm = jnp.stack([mm[0], mm[1]]).T # mm[x, y] gives [x, y] - distances_to_nodes = jax.vmap(lambda pos: jnp.linalg.norm(mm - pos, axis=-1))(state.energy_nodes) + mm = jnp.meshgrid(jnp.arange(self.fixed_env_params.map_width), jnp.arange(self.fixed_env_params.map_height)) + mm = jnp.stack([mm[0], mm[1]]).T # mm[x, y] gives [x, y] + distances_to_nodes = jax.vmap(lambda pos: jnp.linalg.norm(mm - pos, axis=-1))( + state.energy_nodes + ) + def compute_energy_field(node_fn_spec, distances_to_node, mask): fn_i, x, y, z = node_fn_spec - return jnp.where(mask, lax.switch(fn_i.astype(jnp.int16), ENERGY_NODE_FNS, distances_to_node, x, y, z), jnp.zeros_like(distances_to_node)) - energy_field = jax.vmap(compute_energy_field)(state.energy_node_fns, distances_to_nodes, state.energy_nodes_mask) - energy_field = jnp.where(energy_field.mean() < .25, energy_field + (.25 - energy_field.mean()), energy_field) + return jnp.where( + mask, + lax.switch( + fn_i.astype(jnp.int16), ENERGY_NODE_FNS, distances_to_node, x, y, z + ), + jnp.zeros_like(distances_to_node), + ) + + energy_field = jax.vmap(compute_energy_field)( + state.energy_node_fns, distances_to_nodes, state.energy_nodes_mask + ) + energy_field = jnp.where( + energy_field.mean() < 0.25, + energy_field + (0.25 - energy_field.mean()), + energy_field, + ) energy_field = jnp.round(energy_field.sum(0)).astype(jnp.int16) - energy_field = jnp.clip(energy_field, params.min_energy_per_tile, params.max_energy_per_tile) - state = state.replace(map_features=state.map_features.replace(energy=energy_field)) + energy_field = jnp.clip( + energy_field, params.min_energy_per_tile, params.max_energy_per_tile + ) + state = state.replace( + map_features=state.map_features.replace(energy=energy_field) + ) return state - - def compute_sensor_masks(self, state, params): - """Compute the vision power and sensor mask for both teams - + + def compute_sensor_masks(self, state, params: EnvParams): + """Compute the vision power 
and sensor mask for both teams + Algorithm: - For each team, generate a integer vision power array over the map. + For each team, generate a integer vision power array over the map. For each unit in team, add unit sensor range value (its kind of like the units sensing power/depth) to each tile the unit's sensor range Clamp the vision power array to range [0, unit_sensing_range]. @@ -63,25 +109,44 @@ def compute_sensor_masks(self, state, params): Now any time the vision power map has value > 0, the team can sense the tile. This forms the sensor mask """ - vision_power_map_padding = params.unit_sensor_range + vision_power_map_padding = self.fixed_env_params.unit_sensor_range vision_power_map = jnp.zeros( - shape=(params.num_teams, params.map_height + 2 * vision_power_map_padding, params.map_width + 2 * vision_power_map_padding), + shape=( + self.fixed_env_params.num_teams, + self.fixed_env_params.map_height + 2 * vision_power_map_padding, + self.fixed_env_params.map_width + 2 * vision_power_map_padding, + ), dtype=jnp.int16, ) # Update sensor mask based on the sensor range + max_sensor_range = env_params_ranges["unit_sensor_range"][-1] def update_vision_power_map(unit_pos, vision_power_map): - x, y = unit_pos - existing_vision_power = jax.lax.dynamic_slice(vision_power_map, start_indices=(x - params.unit_sensor_range + vision_power_map_padding, y - params.unit_sensor_range + vision_power_map_padding), slice_sizes=(params.unit_sensor_range * 2 + 1, params.unit_sensor_range * 2 + 1)) + x, y = unit_pos + existing_vision_power = jax.lax.dynamic_slice( + vision_power_map, + start_indices=( + x - max_sensor_range + vision_power_map_padding, + y - max_sensor_range + vision_power_map_padding, + ), + slice_sizes=( + max_sensor_range * 2 + 1, + max_sensor_range * 2 + 1, + ), + ) update = jnp.zeros_like(existing_vision_power) - for i in range(params.unit_sensor_range + 1): - update = update.at[i:params.unit_sensor_range * 2 + 1 - i, i:params.unit_sensor_range * 2 + 1 - i].set(i + 1) + for i in range(max_sensor_range + 1): + val = jnp.where(i > max_sensor_range - params.unit_sensor_range, i + 1 - params.unit_sensor_range, 0) + update = update.at[ + i : max_sensor_range * 2 + 1 - i, + i : max_sensor_range * 2 + 1 - i, + ].set(val) vision_power_map = jax.lax.dynamic_update_slice( vision_power_map, update=update + existing_vision_power, start_indices=( - x - params.unit_sensor_range + vision_power_map_padding, - y - params.unit_sensor_range + vision_power_map_padding, + x - max_sensor_range + vision_power_map_padding, + y - max_sensor_range + vision_power_map_padding, ), ) return vision_power_map @@ -105,22 +170,31 @@ def body_fun(carry, i): ) vision_power_map, _ = jax.lax.scan( - body_fun, vision_power_map, jnp.arange(params.max_units) + body_fun, vision_power_map, jnp.arange(self.fixed_env_params.max_units) ) return vision_power_map vision_power_map = jax.vmap(update_team_vision_power_map)( state.units, state.units_mask, vision_power_map ) - vision_power_map = vision_power_map[:, vision_power_map_padding:-vision_power_map_padding, vision_power_map_padding:-vision_power_map_padding] + vision_power_map = vision_power_map[ + :, + vision_power_map_padding:-vision_power_map_padding, + vision_power_map_padding:-vision_power_map_padding, + ] # handle nebula tiles - vision_power_map = vision_power_map - (state.map_features.tile_type == NEBULA_TILE) * params.nebula_tile_vision_reduction - + vision_power_map = ( + vision_power_map + - (state.map_features.tile_type == NEBULA_TILE) + * 
params.nebula_tile_vision_reduction + ) + sensor_mask = vision_power_map > 0 state = state.replace(sensor_mask=sensor_mask) state = state.replace(vision_power_map=vision_power_map) return state + # @functools.partial(jax.jit, static_argnums=(0, 4)) def step_env( self, @@ -129,17 +203,25 @@ def step_env( action: Union[int, float, chex.Array], params: EnvParams, ) -> Tuple[EnvObs, EnvState, jnp.ndarray, jnp.ndarray, jnp.ndarray, Dict[Any, Any]]: - + state = self.compute_energy_features(state, params) action = jnp.stack([action["player_0"], action["player_1"]]) - + # remove all units if the match ended in the previous step indicated by a reset of match_steps to 0 - state = state.replace(units_mask=jnp.where(state.match_steps == 0, jnp.zeros_like(state.units_mask), state.units_mask)) + state = state.replace( + units_mask=jnp.where( + state.match_steps == 0, + jnp.zeros_like(state.units_mask), + state.units_mask, + ) + ) """remove units that have less than 0 energy""" # we remove units at the start of the timestep so that the visualizer can show the unit with negative energy and is marked for removal soon. - state = state.replace(units_mask=(state.units.energy[..., 0] >= 0) & state.units_mask) - + state = state.replace( + units_mask=(state.units.energy[..., 0] >= 0) & state.units_mask + ) + """ process unit movement """ # 0 is do nothing, 1 is move up, 2 is move right, 3 is move down, 4 is move left, 5 is sap # Define movement directions @@ -157,7 +239,9 @@ def step_env( def move_unit(unit: UnitState, action, mask): new_pos = unit.position + directions[action] # Check if the new position is on a map feature of value 2 - is_blocked = state.map_features.tile_type[new_pos[0], new_pos[1]] == ASTEROID_TILE + is_blocked = ( + state.map_features.tile_type[new_pos[0], new_pos[1]] == ASTEROID_TILE + ) enough_energy = unit.energy >= params.unit_move_cost # If blocked, keep the original position # new_pos = jnp.where(is_blocked, unit.position, new_pos) @@ -169,127 +253,288 @@ def move_unit(unit: UnitState, action, mask): [params.map_width - 1, params.map_height - 1], dtype=jnp.int16 ), ) - unit_moved = mask & ~is_blocked & enough_energy & (action < 5) & (action > 0) + unit_moved = ( + mask & ~is_blocked & enough_energy & (action < 5) & (action > 0) + ) # Update the unit's position only if it's active. Note energy is used if unit tries to move off map. Energy is not used if unit tries to move into an asteroid tile. 
- return UnitState(position=jnp.where(unit_moved, new_pos, unit.position), energy=jnp.where(unit_moved, unit.energy - params.unit_move_cost, unit.energy)) + return UnitState( + position=jnp.where(unit_moved, new_pos, unit.position), + energy=jnp.where( + unit_moved, unit.energy - params.unit_move_cost, unit.energy + ), + ) # Move units for both teams move_actions = action[..., 0] state = state.replace( units=jax.vmap( - lambda team_units, team_action, team_mask: jax.vmap(move_unit, in_axes=(0, 0, 0))( - team_units, team_action, team_mask - ), in_axes=(0, 0, 0) + lambda team_units, team_action, team_mask: jax.vmap( + move_unit, in_axes=(0, 0, 0) + )(team_units, team_action, team_mask), + in_axes=(0, 0, 0), )(state.units, move_actions, state.units_mask) ) - + original_unit_energy = state.units.energy """original amount of energy of all units""" - + """apply sap actions""" sap_action_mask = action[..., 0] == 5 sap_action_deltas = action[..., 1:] - - def sap_unit(current_energy: jnp.ndarray, all_units: UnitState, sap_action_mask, sap_action_deltas, units_mask): + + def sap_unit( + current_energy: jnp.ndarray, + all_units: UnitState, + sap_action_mask, + sap_action_deltas, + units_mask, + ): # TODO (stao): clean up this code. It is probably slower than it needs be and could be vmapped perhaps. for t in range(params.num_teams): - other_team_ids = jnp.array([t2 for t2 in range(params.num_teams) if t2 != t]) - team_sap_action_deltas = sap_action_deltas[t] # (max_units, 2) + other_team_ids = jnp.array( + [t2 for t2 in range(params.num_teams) if t2 != t] + ) + team_sap_action_deltas = sap_action_deltas[t] # (max_units, 2) team_sap_action_mask = sap_action_mask[t] - other_team_unit_mask = units_mask[other_team_ids] # (other_teams, max_units) - team_sapped_positions = all_units.position[t] + team_sap_action_deltas # (max_units, 2) + other_team_unit_mask = units_mask[ + other_team_ids + ] # (other_teams, max_units) + team_sapped_positions = ( + all_units.position[t] + team_sap_action_deltas + ) # (max_units, 2) # whether the unit is really sapping or not (needs to exist, have enough energy, and a valid sap action) - team_unit_sapped = units_mask[t] & team_sap_action_mask & (current_energy[t, :, 0] >= params.unit_sap_cost) & (jnp.max(jnp.abs(team_sap_action_deltas), axis=-1) <= params.unit_sap_range) # (max_units) - team_unit_sapped = team_unit_sapped & (team_sapped_positions >= 0).all(-1) & (team_sapped_positions[:, 0] < params.map_width) & (team_sapped_positions[:, 1] < params.map_height) + team_unit_sapped = ( + units_mask[t] + & team_sap_action_mask + & (current_energy[t, :, 0] >= params.unit_sap_cost) + & ( + jnp.max(jnp.abs(team_sap_action_deltas), axis=-1) + <= params.unit_sap_range + ) + ) # (max_units) + team_unit_sapped = ( + team_unit_sapped + & (team_sapped_positions >= 0).all(-1) + & (team_sapped_positions[:, 0] < params.map_width) + & (team_sapped_positions[:, 1] < params.map_height) + ) # the number of times other units are sapped - other_units_sapped_count = jnp.sum(jnp.all(all_units.position[other_team_ids][:, :, None] == team_sapped_positions[None], axis=-1), axis=-1, dtype=jnp.int16) # (len(other_team_ids), max_units) + other_units_sapped_count = jnp.sum( + jnp.all( + all_units.position[other_team_ids][:, :, None] + == team_sapped_positions[None], + axis=-1, + ), + axis=-1, + dtype=jnp.int16, + ) # (len(other_team_ids), max_units) # remove unit_sap_cost energy from opposition units that were in the middle of a sap action. 
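# How other_units_sapped_count (just above) works: all_units.position[other_team_ids][:, :, None]
# has shape (num_other_teams, max_units, 1, 2) while team_sapped_positions[None] broadcasts as
# (1, 1, max_units, 2), so jnp.all(..., axis=-1) marks, for each (enemy unit, sap target) pair,
# whether the targeted tile equals that enemy unit's tile, and jnp.sum(..., axis=-1) counts how
# many sap targets land on each enemy unit; the jnp.where below then only applies unit_sap_cost
# where that count is positive and the unit/sap masks hold.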
all_units = all_units.replace( energy=all_units.energy.at[other_team_ids].set( - jnp.where(team_unit_sapped[None, :, None] & other_team_unit_mask[:, :, None] & (other_units_sapped_count[:, :, None] > 0), - all_units.energy[other_team_ids] - params.unit_sap_cost * other_units_sapped_count[:, :, None], + jnp.where( + team_unit_sapped[None, :, None] + & other_team_unit_mask[:, :, None] + & (other_units_sapped_count[:, :, None] > 0), all_units.energy[other_team_ids] + - params.unit_sap_cost + * other_units_sapped_count[:, :, None], + all_units.energy[other_team_ids], ) ) ) - + # remove unit_sap_cost * unit_sap_dropoff_factor energy from opposition units that were on tiles adjacent to the center of a sap action. - adjacent_offsets = jnp.array([ - [-1, -1], [-1, 0], [-1, 1], [0, -1], [0, 1], [1, -1], [1, 0], [1, 1] - ]) - team_sapped_adjacent_positions = team_sapped_positions[:, None, :] + adjacent_offsets # (max_units, len(adjacent_offsets), 2) - other_units_adjacent_sapped_count = jnp.sum(jnp.all(all_units.position[other_team_ids][:, :, None] == team_sapped_adjacent_positions[None], axis=-1), axis=-1, dtype=jnp.int16) # (len(other_team_ids), max_units) + adjacent_offsets = jnp.array( + [ + [-1, -1], + [-1, 0], + [-1, 1], + [0, -1], + [0, 1], + [1, -1], + [1, 0], + [1, 1], + ] + ) + team_sapped_adjacent_positions = ( + team_sapped_positions[:, None, :] + adjacent_offsets + ) # (max_units, len(adjacent_offsets), 2) + other_units_adjacent_sapped_count = jnp.sum( + jnp.all( + all_units.position[other_team_ids][:, :, None] + == team_sapped_adjacent_positions[None], + axis=-1, + ), + axis=-1, + dtype=jnp.int16, + ) # (len(other_team_ids), max_units) all_units = all_units.replace( energy=all_units.energy.at[other_team_ids].set( - jnp.where(team_unit_sapped[None, :, None] & other_team_unit_mask[:, :, None] & (other_units_adjacent_sapped_count[:, :, None] > 0), - all_units.energy[other_team_ids] - jnp.array(params.unit_sap_cost * params.unit_sap_dropoff_factor * other_units_adjacent_sapped_count[:, :, None], dtype=jnp.int16), + jnp.where( + team_unit_sapped[None, :, None] + & other_team_unit_mask[:, :, None] + & (other_units_adjacent_sapped_count[:, :, None] > 0), all_units.energy[other_team_ids] + - jnp.array( + params.unit_sap_cost + * params.unit_sap_dropoff_factor + * other_units_adjacent_sapped_count[:, :, None], + dtype=jnp.int16, + ), + all_units.energy[other_team_ids], ) ) ) - + # remove unit_sap_cost energy from units that tried to sap some position within the unit's range - all_units = all_units.replace(energy=all_units.energy.at[t].set(jnp.where(team_unit_sapped[:, None], all_units.energy[t] - params.unit_sap_cost, all_units.energy[t]))) + all_units = all_units.replace( + energy=all_units.energy.at[t].set( + jnp.where( + team_unit_sapped[:, None], + all_units.energy[t] - params.unit_sap_cost, + all_units.energy[t], + ) + ) + ) return all_units - + state = state.replace( - units=sap_unit(original_unit_energy, state.units, sap_action_mask, sap_action_deltas, state.units_mask) + units=sap_unit( + original_unit_energy, + state.units, + sap_action_mask, + sap_action_deltas, + state.units_mask, + ) ) - + """resolve collisions and energy void fields""" - - # compute energy void fields for all teams and the energy + unit counts - unit_aggregate_energy_void_map = jnp.zeros(shape=(params.num_teams, params.map_width, params.map_height), dtype=jnp.float32) + + # compute energy void fields for all teams and the energy + unit counts + unit_aggregate_energy_void_map = jnp.zeros( + shape=(params.num_teams, 
params.map_width, params.map_height), + dtype=jnp.float32, + ) unit_counts_map = self.compute_unit_counts_map(state, params) # TODO (stao): this doesn't need to be a float? - unit_aggregate_energy_map = jnp.zeros(shape=(params.num_teams, params.map_width, params.map_height), dtype=jnp.float32) + unit_aggregate_energy_map = jnp.zeros( + shape=(params.num_teams, params.map_width, params.map_height), + dtype=jnp.float32, + ) for t in range(params.num_teams): + def scan_body(carry, x): agg_energy_void_map, agg_energy_map = carry unit_energy, unit_position, unit_mask = x - agg_energy_map = agg_energy_map.at[unit_position[0], unit_position[1]].add(unit_energy[0] * unit_mask) + agg_energy_map = agg_energy_map.at[ + unit_position[0], unit_position[1] + ].add(unit_energy[0] * unit_mask) for deltas in [(-1, 0), (1, 0), (0, -1), (0, 1)]: new_pos = unit_position + jnp.array(deltas) - in_map = (new_pos[0] >= 0) & (new_pos[0] < params.map_width) & (new_pos[1] >= 0) & (new_pos[1] < params.map_height) - agg_energy_void_map = agg_energy_void_map.at[new_pos[0], new_pos[1]].add(unit_energy[0] * unit_mask * in_map) + in_map = ( + (new_pos[0] >= 0) + & (new_pos[0] < params.map_width) + & (new_pos[1] >= 0) + & (new_pos[1] < params.map_height) + ) + agg_energy_void_map = agg_energy_void_map.at[ + new_pos[0], new_pos[1] + ].add(unit_energy[0] * unit_mask * in_map) return (agg_energy_void_map, agg_energy_map), None - agg_energy_void_map, agg_energy_map = jax.lax.scan(scan_body, (unit_aggregate_energy_void_map[t], unit_aggregate_energy_map[t]), (original_unit_energy[t], state.units.position[t], state.units_mask[t]))[0] - unit_aggregate_energy_void_map = unit_aggregate_energy_void_map.at[t].add(agg_energy_void_map) - unit_aggregate_energy_map = unit_aggregate_energy_map.at[t].add(agg_energy_map) + + agg_energy_void_map, agg_energy_map = jax.lax.scan( + scan_body, + (unit_aggregate_energy_void_map[t], unit_aggregate_energy_map[t]), + (original_unit_energy[t], state.units.position[t], state.units_mask[t]), + )[0] + unit_aggregate_energy_void_map = unit_aggregate_energy_void_map.at[t].add( + agg_energy_void_map + ) + unit_aggregate_energy_map = unit_aggregate_energy_map.at[t].add( + agg_energy_map + ) # resolve collisions and keep only the surviving units for t in range(params.num_teams): - other_team_ids = jnp.array([t2 for t2 in range(params.num_teams) if t2 != t]) + other_team_ids = jnp.array( + [t2 for t2 in range(params.num_teams) if t2 != t] + ) # get the energy map for the current team - opposing_unit_counts_map = unit_counts_map[other_team_ids].sum(axis=0) # (map_width, map_height) + opposing_unit_counts_map = unit_counts_map[other_team_ids].sum( + axis=0 + ) # (map_width, map_height) team_energy_map = unit_aggregate_energy_map[t] - opposing_aggregate_energy_map = unit_aggregate_energy_map[other_team_ids].max(axis=0) # (map_width, map_height) + opposing_aggregate_energy_map = unit_aggregate_energy_map[ + other_team_ids + ].max( + axis=0 + ) # (map_width, map_height) # unit survives if there are opposing units on the tile, and if the opposing unit stack has less energy on the tile than the current unit surviving_unit_mask = jax.vmap( - lambda unit_position : (opposing_unit_counts_map[unit_position[0], unit_position[1]] == 0) | - (opposing_aggregate_energy_map[unit_position[0], unit_position[1]] < team_energy_map[unit_position[0], unit_position[1]]) + lambda unit_position: ( + opposing_unit_counts_map[unit_position[0], unit_position[1]] == 0 + ) + | ( + opposing_aggregate_energy_map[unit_position[0], 
unit_position[1]] + < team_energy_map[unit_position[0], unit_position[1]] + ) )(state.units.position[t]) - state = state.replace(units_mask=state.units_mask.at[t].set(surviving_unit_mask & state.units_mask[t])) + state = state.replace( + units_mask=state.units_mask.at[t].set( + surviving_unit_mask & state.units_mask[t] + ) + ) # apply energy void fields for t in range(params.num_teams): - other_team_ids = jnp.array([t2 for t2 in range(params.num_teams) if t2 != t]) - oppposition_energy_void_map = unit_aggregate_energy_void_map[other_team_ids].sum(axis=0) # (map_width, map_height) + other_team_ids = jnp.array( + [t2 for t2 in range(params.num_teams) if t2 != t] + ) + oppposition_energy_void_map = unit_aggregate_energy_void_map[ + other_team_ids + ].sum( + axis=0 + ) # (map_width, map_height) # unit on team t loses energy to void field equal to params.unit_energy_void_factor * void_energy / num units stacked with unit on the same tile - team_unit_energy = state.units.energy[t] - jnp.floor(jax.vmap(lambda unit_position: params.unit_energy_void_factor * oppposition_energy_void_map[unit_position[0], unit_position[1]] / unit_counts_map[t][unit_position[0], unit_position[1]])(state.units.position[t])[..., None]).astype(jnp.int16) - state = state.replace(units=state.units.replace(energy=state.units.energy.at[t].set(team_unit_energy))) + team_unit_energy = state.units.energy[t] - jnp.floor( + jax.vmap( + lambda unit_position: params.unit_energy_void_factor + * oppposition_energy_void_map[unit_position[0], unit_position[1]] + / unit_counts_map[t][unit_position[0], unit_position[1]] + )(state.units.position[t])[..., None] + ).astype(jnp.int16) + state = state.replace( + units=state.units.replace( + energy=state.units.energy.at[t].set(team_unit_energy) + ) + ) """apply energy field to the units""" + # Update unit energy based on the energy field and nebula tileof their current position def update_unit_energy(unit: UnitState, mask): x, y = unit.position - energy_gain = state.map_features.energy[x, y] - (state.map_features.tile_type[x, y] == NEBULA_TILE) * params.nebula_tile_energy_reduction + energy_gain = ( + state.map_features.energy[x, y] + - (state.map_features.tile_type[x, y] == NEBULA_TILE) + * params.nebula_tile_energy_reduction + ) # if energy gain is less than 0 # new_energy = jnp.where((unit.energy < 0) & (energy_gain < 0)) - new_energy = jnp.clip(unit.energy + energy_gain, params.min_unit_energy, params.max_unit_energy) + new_energy = jnp.clip( + unit.energy + energy_gain, + params.min_unit_energy, + params.max_unit_energy, + ) # if unit already had negative energy due to opposition units and after energy field/nebula tile it is still below 0, then it will be removed next step # and we keep its energy value at whatever it is - new_energy = jnp.where((unit.energy < 0) & (unit.energy + energy_gain < 0), unit.energy, new_energy) - return UnitState(position=unit.position, energy=jnp.where(mask, new_energy, unit.energy)) + new_energy = jnp.where( + (unit.energy < 0) & (unit.energy + energy_gain < 0), + unit.energy, + new_energy, + ) + return UnitState( + position=unit.position, energy=jnp.where(mask, new_energy, unit.energy) + ) # Apply the energy update for all units of both teams state = state.replace( @@ -299,69 +544,193 @@ def update_unit_energy(unit: UnitState, mask): ) )(state.units, state.units_mask) ) - + """spawn new units in""" - spawn_units_in = (state.match_steps % params.spawn_rate == 0) + spawn_units_in = state.match_steps % params.spawn_rate == 0 + # TODO (stao): only logic in code 
that probably doesn't not handle more than 2 teams, everything else is vmapped across teams def spawn_team_units(state: EnvState): team_0_unit_count = state.units_mask[0].sum() team_1_unit_count = state.units_mask[1].sum() team_0_new_unit_id = state.units_mask[0].argmin() team_1_new_unit_id = state.units_mask[1].argmin() - state = state.replace(units=state.units.replace(position=jnp.where(team_0_unit_count < params.max_units, state.units.position.at[0, team_0_new_unit_id, :].set(jnp.array([0, 0], dtype=jnp.int16)), state.units.position))) - state = state.replace(units=state.units.replace(energy=jnp.where(team_0_unit_count < params.max_units, state.units.energy.at[0, team_0_new_unit_id, :].set(jnp.array([params.init_unit_energy], dtype=jnp.int16)), state.units.energy))) - state = state.replace(units=state.units.replace(position=jnp.where(team_1_unit_count < params.max_units, state.units.position.at[1, team_1_new_unit_id, :].set(jnp.array([params.map_width - 1, params.map_height - 1], dtype=jnp.int16)), state.units.position))) - state = state.replace(units=state.units.replace(energy=jnp.where(team_1_unit_count < params.max_units, state.units.energy.at[1, team_1_new_unit_id, :].set(jnp.array([params.init_unit_energy], dtype=jnp.int16)), state.units.energy))) - state = state.replace(units_mask=state.units_mask.at[0, team_0_new_unit_id].set(jnp.where(team_0_unit_count < params.max_units, True, state.units_mask[0, team_0_new_unit_id]))) - state = state.replace(units_mask=state.units_mask.at[1, team_1_new_unit_id].set(jnp.where(team_1_unit_count < params.max_units, True, state.units_mask[1, team_1_new_unit_id]))) + state = state.replace( + units=state.units.replace( + position=jnp.where( + team_0_unit_count < params.max_units, + state.units.position.at[0, team_0_new_unit_id, :].set( + jnp.array([0, 0], dtype=jnp.int16) + ), + state.units.position, + ) + ) + ) + state = state.replace( + units=state.units.replace( + energy=jnp.where( + team_0_unit_count < params.max_units, + state.units.energy.at[0, team_0_new_unit_id, :].set( + jnp.array([params.init_unit_energy], dtype=jnp.int16) + ), + state.units.energy, + ) + ) + ) + state = state.replace( + units=state.units.replace( + position=jnp.where( + team_1_unit_count < params.max_units, + state.units.position.at[1, team_1_new_unit_id, :].set( + jnp.array( + [params.map_width - 1, params.map_height - 1], + dtype=jnp.int16, + ) + ), + state.units.position, + ) + ) + ) + state = state.replace( + units=state.units.replace( + energy=jnp.where( + team_1_unit_count < params.max_units, + state.units.energy.at[1, team_1_new_unit_id, :].set( + jnp.array([params.init_unit_energy], dtype=jnp.int16) + ), + state.units.energy, + ) + ) + ) + state = state.replace( + units_mask=state.units_mask.at[0, team_0_new_unit_id].set( + jnp.where( + team_0_unit_count < params.max_units, + True, + state.units_mask[0, team_0_new_unit_id], + ) + ) + ) + state = state.replace( + units_mask=state.units_mask.at[1, team_1_new_unit_id].set( + jnp.where( + team_1_unit_count < params.max_units, + True, + state.units_mask[1, team_1_new_unit_id], + ) + ) + ) # state = jnp.where(team_0_unit_count < params.max_units, spawn_unit(state, 0, team_0_new_unit_id, [0, 0], params), state) # state = jnp.where(team_1_unit_count < params.max_units, spawn_unit(state, 1, team_1_new_unit_id, [params.map_width - 1, params.map_height - 1], params), state) return state - state = jax.lax.cond(spawn_units_in, lambda: spawn_team_units(state), lambda: state) + + state = jax.lax.cond( + spawn_units_in, lambda: 
spawn_team_units(state), lambda: state + ) state = self.compute_sensor_masks(state, params) - + # Shift objects around in space # Move the nebula tiles in state.map_features.tile_types up by 1 and to the right by 1 # this is also symmetric nebula tile movement - new_tile_types_map = jnp.roll(state.map_features.tile_type, shift=(1 * jnp.sign(params.nebula_tile_drift_speed), -1 * jnp.sign(params.nebula_tile_drift_speed)), axis=(0, 1)) - new_tile_types_map = jnp.where(state.steps * params.nebula_tile_drift_speed % 1 == 0, new_tile_types_map, state.map_features.tile_type) + new_tile_types_map = jnp.roll( + state.map_features.tile_type, + shift=( + 1 * jnp.sign(params.nebula_tile_drift_speed), + -1 * jnp.sign(params.nebula_tile_drift_speed), + ), + axis=(0, 1), + ) + new_tile_types_map = jnp.where( + state.steps * params.nebula_tile_drift_speed % 1 == 0, + new_tile_types_map, + state.map_features.tile_type, + ) # new_energy_nodes = state.energy_nodes + jnp.array([1 * jnp.sign(params.energy_node_drift_speed), -1 * jnp.sign(params.energy_node_drift_speed)]) - - energy_node_deltas = jnp.round(jax.random.uniform(key=key, shape=(params.max_energy_nodes // 2, 2), minval=-params.energy_node_drift_magnitude, maxval=params.energy_node_drift_magnitude)).astype(jnp.int16) - energy_node_deltas_symmetric = jnp.stack([-energy_node_deltas[:, 1], -energy_node_deltas[:, 0]], axis=-1) + + energy_node_deltas = jnp.round( + jax.random.uniform( + key=key, + shape=(params.max_energy_nodes // 2, 2), + minval=-params.energy_node_drift_magnitude, + maxval=params.energy_node_drift_magnitude, + ) + ).astype(jnp.int16) + energy_node_deltas_symmetric = jnp.stack( + [-energy_node_deltas[:, 1], -energy_node_deltas[:, 0]], axis=-1 + ) # TODO symmetric movement # energy_node_deltas = jnp.round(jax.random.uniform(key=key, shape=(params.max_energy_nodes // 2, 2), minval=-params.energy_node_drift_magnitude, maxval=params.energy_node_drift_magnitude)).astype(jnp.int16) - energy_node_deltas = jnp.concatenate((energy_node_deltas, energy_node_deltas_symmetric)) - new_energy_nodes = jnp.clip(state.energy_nodes + energy_node_deltas, min=jnp.array([0, 0]), max=jnp.array([params.map_width, params.map_height])) - new_energy_nodes = jnp.where(state.steps * params.energy_node_drift_speed % 1 == 0, new_energy_nodes, state.energy_nodes) - state = state.replace(map_features=state.map_features.replace(tile_type=new_tile_types_map), energy_nodes=new_energy_nodes) + energy_node_deltas = jnp.concatenate( + (energy_node_deltas, energy_node_deltas_symmetric) + ) + new_energy_nodes = jnp.clip( + state.energy_nodes + energy_node_deltas, + min=jnp.array([0, 0]), + max=jnp.array([params.map_width, params.map_height]), + ) + new_energy_nodes = jnp.where( + state.steps * params.energy_node_drift_speed % 1 == 0, + new_energy_nodes, + state.energy_nodes, + ) + state = state.replace( + map_features=state.map_features.replace(tile_type=new_tile_types_map), + energy_nodes=new_energy_nodes, + ) - # Compute relic scores def team_relic_score(unit_counts_map): scores = (unit_counts_map > 0) & (state.relic_nodes_map_weights > 0) return jnp.sum(scores, dtype=jnp.int32) - + # note we need to recompue unit counts since units can get removed due to collisions - team_scores = jax.vmap(team_relic_score)(self.compute_unit_counts_map(state, params)) - # Update team points - state = state.replace( - team_points=state.team_points + team_scores + team_scores = jax.vmap(team_relic_score)( + self.compute_unit_counts_map(state, params) ) + # Update team points + state = 
state.replace(team_points=state.team_points + team_scores) # if match ended, then remove all units, update team wins, reset team points - winner_by_points = jnp.where(state.team_points.max() > state.team_points.min(), jnp.argmax(state.team_points), -1) - winner_by_energy = jnp.sum(state.units.energy[..., 0] * state.units_mask, axis=1) - winner_by_energy = jnp.where(winner_by_energy.max() > winner_by_energy.min(), jnp.argmax(winner_by_energy), -1) + winner_by_points = jnp.where( + state.team_points.max() > state.team_points.min(), + jnp.argmax(state.team_points), + -1, + ) + winner_by_energy = jnp.sum( + state.units.energy[..., 0] * state.units_mask, axis=1 + ) + winner_by_energy = jnp.where( + winner_by_energy.max() > winner_by_energy.min(), + jnp.argmax(winner_by_energy), + -1, + ) - winner = jnp.where(winner_by_points != -1, winner_by_points, jnp.where(winner_by_energy != -1, winner_by_energy, jax.random.randint(key, shape=(), minval=0, maxval=params.num_teams))) + winner = jnp.where( + winner_by_points != -1, + winner_by_points, + jnp.where( + winner_by_energy != -1, + winner_by_energy, + jax.random.randint(key, shape=(), minval=0, maxval=params.num_teams), + ), + ) match_ended = state.match_steps >= params.max_steps_in_match - - state = state.replace(match_steps=jnp.where(match_ended, -1, state.match_steps), team_points=jnp.where(match_ended, jnp.zeros_like(state.team_points), state.team_points), team_wins=jnp.where(match_ended, state.team_wins.at[winner].add(1), state.team_wins)) + + state = state.replace( + match_steps=jnp.where(match_ended, -1, state.match_steps), + team_points=jnp.where( + match_ended, jnp.zeros_like(state.team_points), state.team_points + ), + team_wins=jnp.where( + match_ended, state.team_wins.at[winner].add(1), state.team_wins + ), + ) # Update state's step count state = state.replace(steps=state.steps + 1, match_steps=state.match_steps + 1) - truncated = state.steps >= (params.max_steps_in_match + 1) * params.match_count_per_episode + truncated = ( + state.steps + >= (params.max_steps_in_match + 1) * params.match_count_per_episode + ) reward = dict() for k in range(params.num_teams): reward[f"player_{k}"] = state.team_wins[k] @@ -380,7 +749,18 @@ def reset_env( ) -> Tuple[EnvObs, EnvState]: """Reset environment state by sampling initial position.""" - state = gen_state(key=key, params=params) + state = gen_state( + key=key, + env_params=params, + max_units=self.fixed_env_params.max_units, + num_teams=self.fixed_env_params.num_teams, + map_type=self.fixed_env_params.map_type, + map_width=self.fixed_env_params.map_width, + map_height=self.fixed_env_params.map_height, + max_energy_nodes=self.fixed_env_params.max_energy_nodes, + max_relic_nodes=self.fixed_env_params.max_relic_nodes, + relic_config_size=self.fixed_env_params.relic_config_size, + ) state = self.compute_energy_features(state, params) state = self.compute_sensor_masks(state, params) @@ -416,7 +796,7 @@ def step( state = state_st # Auto-reset environment based on done done = terminated | truncated - + # all agents terminate/truncate at same time terminated_dict = dict() truncated_dict = dict() @@ -426,7 +806,7 @@ def step( info[f"player_{k}"] = dict() return obs, state, reward, terminated_dict, truncated_dict, info - @functools.partial(jax.jit, static_argnums=(0, )) + @functools.partial(jax.jit, static_argnums=(0,)) def reset( self, key: chex.PRNGKey, params: Optional[EnvParams] = None ) -> Tuple[chex.Array, EnvState]: @@ -441,38 +821,62 @@ def reset( def get_obs(self, state: EnvState, params=None, 
key=None) -> EnvObs: """Return observation from raw state, handling partial observability.""" obs = dict() - + def update_unit_mask(unit_position, unit_mask, sensor_mask): return unit_mask & sensor_mask[unit_position[0], unit_position[1]] + def update_team_unit_mask(unit_position, unit_mask, sensor_mask): - return jax.vmap(update_unit_mask, in_axes=(0, 0, None))(unit_position, unit_mask, sensor_mask) - + return jax.vmap(update_unit_mask, in_axes=(0, 0, None))( + unit_position, unit_mask, sensor_mask + ) + def update_relic_nodes_mask(relic_nodes_mask, relic_nodes, sensor_mask): - return jax.vmap(lambda r_mask, r, s_mask: r_mask & s_mask[r[0], r[1]], in_axes=(0, 0, None))(relic_nodes_mask, relic_nodes, sensor_mask) - - for t in range(params.num_teams): - other_team_ids = jnp.array([t2 for t2 in range(params.num_teams) if t2 != t]) - new_unit_masks = jax.vmap(update_team_unit_mask, in_axes=(0, 0, None))(state.units.position[other_team_ids], state.units_mask[other_team_ids], state.sensor_mask[t]) + return jax.vmap( + lambda r_mask, r, s_mask: r_mask & s_mask[r[0], r[1]], + in_axes=(0, 0, None), + )(relic_nodes_mask, relic_nodes, sensor_mask) + + for t in range(self.fixed_env_params.num_teams): + other_team_ids = jnp.array( + [t2 for t2 in range(self.fixed_env_params.num_teams) if t2 != t] + ) + new_unit_masks = jax.vmap(update_team_unit_mask, in_axes=(0, 0, None))( + state.units.position[other_team_ids], + state.units_mask[other_team_ids], + state.sensor_mask[t], + ) new_unit_masks = state.units_mask.at[other_team_ids].set(new_unit_masks) - - new_relic_nodes_mask = update_relic_nodes_mask(state.relic_nodes_mask, state.relic_nodes, state.sensor_mask[t]) + + new_relic_nodes_mask = update_relic_nodes_mask( + state.relic_nodes_mask, state.relic_nodes, state.sensor_mask[t] + ) team_obs = EnvObs( units=UnitState( - position=jnp.where(new_unit_masks[..., None], state.units.position, -1), - energy=jnp.where(new_unit_masks[..., None], state.units.energy, -1)[..., 0], + position=jnp.where( + new_unit_masks[..., None], state.units.position, -1 + ), + energy=jnp.where(new_unit_masks[..., None], state.units.energy, -1)[ + ..., 0 + ], ), units_mask=new_unit_masks, sensor_mask=state.sensor_mask[t], map_features=MapTile( - energy=jnp.where(state.sensor_mask[t], state.map_features.energy, -1), - tile_type=jnp.where(state.sensor_mask[t], state.map_features.tile_type, -1), + energy=jnp.where( + state.sensor_mask[t], state.map_features.energy, -1 + ), + tile_type=jnp.where( + state.sensor_mask[t], state.map_features.tile_type, -1 + ), ), team_points=state.team_points, team_wins=state.team_wins, steps=state.steps, match_steps=state.match_steps, - relic_nodes=jnp.where(new_relic_nodes_mask[..., None], state.relic_nodes, -1), - relic_nodes_mask=new_relic_nodes_mask + relic_nodes=jnp.where( + new_relic_nodes_mask[..., None], state.relic_nodes, -1 + ), + relic_nodes_mask=new_relic_nodes_mask, ) obs[f"player_{t}"] = team_obs return obs @@ -497,7 +901,9 @@ def action_space(self, params: EnvParams): low[:, 1:] = -params.unit_sap_range high = np.ones((params.max_units, 3)) * 6 high[:, 1:] = params.unit_sap_range - return spaces.Dict(dict(player_0=MultiDiscrete(low, high), player_1=MultiDiscrete(low, high))) + return spaces.Dict( + dict(player_0=MultiDiscrete(low, high), player_1=MultiDiscrete(low, high)) + ) def observation_space(self, params: EnvParams): """Observation space of the environment.""" diff --git a/src/luxai_s3/params.py b/src/luxai_s3/params.py index 1417d00..c079a26 100644 --- a/src/luxai_s3/params.py +++ 
b/src/luxai_s3/params.py @@ -2,6 +2,7 @@ MAP_TYPES = ["dev0", "random"] + @struct.dataclass class EnvParams: max_steps_in_match: int = 100 @@ -21,7 +22,6 @@ class EnvParams: unit_move_cost: int = 2 spawn_rate: int = 5 - unit_sap_cost: int = 10 """ The unit sap cost is the amount of energy a unit uses when it saps another unit. Can change between games. @@ -39,13 +39,11 @@ class EnvParams: The unit energy void factor multiplied by unit_energy """ - # configs for energy nodes max_energy_nodes: int = 6 max_energy_per_tile: int = 20 min_energy_per_tile: int = -20 - max_relic_nodes: int = 6 relic_config_size: int = 5 fog_of_war: bool = True @@ -66,30 +64,30 @@ class EnvParams: The nebula tile vision reduction is the amount of vision reduction a nebula tile provides. A tile can be seen if the vision power over it is > 0. """ - + nebula_tile_energy_reduction: int = 0 """amount of energy nebula tiles reduce from a unit""" - - + nebula_tile_drift_speed: float = -0.05 """ how fast nebula tiles drift in one of the diagonal directions over time. If positive, flows to the top/right, negative flows to bottom/left """ # TODO (stao): allow other kinds of symmetric drifts? - + energy_node_drift_speed: int = 0.02 """ how fast energy nodes will move around over time """ energy_node_drift_magnitude: int = 5 - + # option to change sap configurations + env_params_ranges = dict( - map_type=[1], + # map_type=[1], unit_move_cost=list(range(1, 6)), unit_sensor_range=list(range(2, 5)), - nebula_tile_vision_reduction=list(range(0,4)), + nebula_tile_vision_reduction=list(range(0, 4)), nebula_tile_energy_reduction=[0, 10, 100], unit_sap_cost=list(range(30, 51)), unit_sap_range=list(range(3, 8)), @@ -98,5 +96,5 @@ class EnvParams: # map randomizations nebula_tile_drift_speed=[-0.05, -0.025, 0.025, 0.05], energy_node_drift_speed=[0.01, 0.02, 0.03, 0.04, 0.05], - energy_node_drift_magnitude=list(range(3, 6)) + energy_node_drift_magnitude=list(range(3, 6)), ) diff --git a/src/luxai_s3/pygame_render.py b/src/luxai_s3/pygame_render.py index fe0a304..72c8a60 100644 --- a/src/luxai_s3/pygame_render.py +++ b/src/luxai_s3/pygame_render.py @@ -1,6 +1,7 @@ from luxai_s3.params import EnvParams from luxai_s3.state import ASTEROID_TILE, NEBULA_TILE, EnvState import numpy as np + try: import pygame except: @@ -173,14 +174,28 @@ def draw_rect_alpha(surface, color, rect): if energy_field_value > 0: draw_rect_alpha( self.surface, - (0, 255, 0, 255 * energy_field_value / params.max_energy_per_tile), - pygame.Rect(x * TILE_SIZE, y * TILE_SIZE, TILE_SIZE, TILE_SIZE), + ( + 0, + 255, + 0, + 255 * energy_field_value / params.max_energy_per_tile, + ), + pygame.Rect( + x * TILE_SIZE, y * TILE_SIZE, TILE_SIZE, TILE_SIZE + ), ) else: draw_rect_alpha( self.surface, - (255, 0, 0, 255 * energy_field_value / params.min_energy_per_tile), - pygame.Rect(x * TILE_SIZE, y * TILE_SIZE, TILE_SIZE, TILE_SIZE), + ( + 255, + 0, + 0, + 255 * energy_field_value / params.min_energy_per_tile, + ), + pygame.Rect( + x * TILE_SIZE, y * TILE_SIZE, TILE_SIZE, TILE_SIZE + ), ) # if self.display_options["show_vision_power_map"]: # print(state.vision_power_map.shape) diff --git a/src/luxai_s3/spaces.py b/src/luxai_s3/spaces.py index 36c68ad..ea8d6be 100644 --- a/src/luxai_s3/spaces.py +++ b/src/luxai_s3/spaces.py @@ -18,7 +18,8 @@ def __init__(self, low: np.ndarray, high: np.ndarray): def sample(self, rng: chex.PRNGKey) -> chex.Array: return ( - jax.random.uniform(rng, shape=self.shape, minval=0, maxval=1) * self.dist + self.low + jax.random.uniform(rng, 
shape=self.shape, minval=0, maxval=1) * self.dist + + self.low ).astype(self.dtype) def contains(self, x) -> jnp.ndarray: diff --git a/src/luxai_s3/state.py b/src/luxai_s3/state.py index 6eda387..625adea 100644 --- a/src/luxai_s3/state.py +++ b/src/luxai_s3/state.py @@ -177,24 +177,24 @@ def state_to_flat_obs(state: EnvState) -> chex.Array: def flat_obs_to_state(flat_obs: chex.Array) -> EnvState: pass - -def gen_state(key: chex.PRNGKey, params: EnvParams) -> EnvState: - generated = gen_map(key, params) +@functools.partial(jax.jit, static_argnums=(2, 3, 4, 5, 6, 7, 8, 9)) +def gen_state(key: chex.PRNGKey, env_params: EnvParams, max_units: int, num_teams: int, map_type: int, map_width: int, map_height: int, max_energy_nodes: int, max_relic_nodes: int, relic_config_size: int) -> EnvState: + generated = gen_map(key, env_params, map_type, map_width, map_height, max_energy_nodes, max_relic_nodes, relic_config_size) relic_nodes_map_weights = jnp.zeros( - shape=(params.map_width, params.map_height), dtype=jnp.int16 + shape=(map_width, map_height), dtype=jnp.int16 ) # TODO (this could be optimized better) def update_relic_node(relic_nodes_map_weights, relic_data): relic_node, relic_node_config, mask = relic_data - start_y = relic_node[1] - params.relic_config_size // 2 - start_x = relic_node[0] - params.relic_config_size // 2 - for dy in range(params.relic_config_size): - for dx in range(params.relic_config_size): + start_y = relic_node[1] - relic_config_size // 2 + start_x = relic_node[0] - relic_config_size // 2 + for dy in range(relic_config_size): + for dx in range(relic_config_size): y, x = start_y + dy, start_x + dx valid_pos = jnp.logical_and( jnp.logical_and(y >= 0, x >= 0), - jnp.logical_and(y < params.map_height, x < params.map_width), + jnp.logical_and(y < map_height, x < map_width), ) relic_nodes_map_weights = jnp.where( valid_pos & mask, @@ -214,12 +214,12 @@ def update_relic_node(relic_nodes_map_weights, relic_data): ), ) state = EnvState( - units=UnitState(position=jnp.zeros(shape=(params.num_teams, params.max_units, 2), dtype=jnp.int16), energy=jnp.zeros(shape=(params.num_teams, params.max_units, 1), dtype=jnp.int16)), + units=UnitState(position=jnp.zeros(shape=(num_teams, max_units, 2), dtype=jnp.int16), energy=jnp.zeros(shape=(num_teams, max_units, 1), dtype=jnp.int16)), units_mask=jnp.zeros( - shape=(params.num_teams, params.max_units), dtype=jnp.bool + shape=(num_teams, max_units), dtype=jnp.bool ), - team_points=jnp.zeros(shape=(params.num_teams), dtype=jnp.int32), - team_wins=jnp.zeros(shape=(params.num_teams), dtype=jnp.int32), + team_points=jnp.zeros(shape=(num_teams), dtype=jnp.int32), + team_wins=jnp.zeros(shape=(num_teams), dtype=jnp.int32), energy_nodes=generated["energy_nodes"], energy_node_fns=generated["energy_node_fns"], energy_nodes_mask=generated["energy_nodes_mask"], @@ -229,57 +229,31 @@ def update_relic_node(relic_nodes_map_weights, relic_data): relic_node_configs=generated["relic_node_configs"], relic_nodes_map_weights=relic_nodes_map_weights, sensor_mask=jnp.zeros( - shape=(params.num_teams, params.map_height, params.map_width), + shape=(num_teams, map_height, map_width), dtype=jnp.bool, ), - vision_power_map=jnp.zeros(shape=(params.num_teams, params.map_height, params.map_width), dtype=jnp.int16), + vision_power_map=jnp.zeros(shape=(num_teams, map_height, map_width), dtype=jnp.int16), map_features=generated["map_features"], ) - - # state = spawn_unit(state, 0, 0, [0, 0], params) - # state = spawn_unit(state, 0, 1, [0, 0], params) - # state = spawn_unit(state, 
0, 2, [0, 0]) - # state = spawn_unit(state, 1, 0, [15, 15], params) - # state = spawn_unit(state, 1, 1, [15, 15], params) - # state = spawn_unit(state, 1, 2, [15, 15]) - return state - - -def spawn_unit( - state: EnvState, team: int, unit_id: int, position: chex.Array, params: EnvParams -) -> EnvState: - unit_state = state.units - unit_state = unit_state.replace(position=unit_state.position.at[team, unit_id, :].set(jnp.array(position, dtype=jnp.int16))) - unit_state = unit_state.replace(energy=unit_state.energy.at[team, unit_id, :].set(jnp.array([params.init_unit_energy], dtype=jnp.int16))) - # state = state.replace( - # units - # # units=state.units.at[team, unit_id, :].set( - # # jnp.array([position[0], position[1], 0], dtype=jnp.int16) - # # ) - # ) - state = state.replace(units=unit_state, units_mask=state.units_mask.at[team, unit_id].set(True)) return state -def set_tile(map_features: MapTile, x: int, y: int, tile_type: int) -> MapTile: - return map_features.replace(tile_type=map_features.tile_type.at[x, y].set(tile_type)) - -# @functools.partial(jax.jit, static_argnums=(1,)) -def gen_map(key: chex.PRNGKey, params: EnvParams) -> chex.Array: +@functools.partial(jax.jit, static_argnums=(2, 3, 4, 5, 6, 7)) +def gen_map(key: chex.PRNGKey, params: EnvParams, map_type: int, map_height: int, map_width: int, max_energy_nodes: int, max_relic_nodes: int, relic_config_size: int) -> chex.Array: map_features = MapTile(energy=jnp.zeros( - shape=(params.map_height, params.map_width), dtype=jnp.int16 + shape=(map_height, map_width), dtype=jnp.int16 ), tile_type=jnp.zeros( - shape=(params.map_height, params.map_width), dtype=jnp.int16 + shape=(map_height, map_width), dtype=jnp.int16 )) - energy_nodes = jnp.zeros(shape=(params.max_energy_nodes, 2), dtype=jnp.int16) - energy_nodes_mask = jnp.zeros(shape=(params.max_energy_nodes), dtype=jnp.bool) - relic_nodes = jnp.zeros(shape=(params.max_relic_nodes, 2), dtype=jnp.int16) - relic_nodes_mask = jnp.zeros(shape=(params.max_relic_nodes), dtype=jnp.bool) + energy_nodes = jnp.zeros(shape=(max_energy_nodes, 2), dtype=jnp.int16) + energy_nodes_mask = jnp.zeros(shape=(max_energy_nodes), dtype=jnp.bool) + relic_nodes = jnp.zeros(shape=(max_relic_nodes, 2), dtype=jnp.int16) + relic_nodes_mask = jnp.zeros(shape=(max_relic_nodes), dtype=jnp.bool) - if MAP_TYPES[params.map_type] == "random": + if MAP_TYPES[map_type] == "random": ### Generate nebula tiles ### key, subkey = jax.random.split(key) - perlin_noise = generate_perlin_noise_2d(subkey, (params.map_height, params.map_width), (4, 4)) + perlin_noise = generate_perlin_noise_2d(subkey, (map_height, map_width), (4, 4)) noise = jnp.where(perlin_noise > 0.5, 1, 0) # mirror along diagonal noise = noise | noise.T @@ -288,7 +262,7 @@ def gen_map(key: chex.PRNGKey, params: EnvParams) -> chex.Array: ### Generate asteroid tiles ### key, subkey = jax.random.split(key) - perlin_noise = generate_perlin_noise_2d(subkey, (params.map_height, params.map_width), (8, 8)) + perlin_noise = generate_perlin_noise_2d(subkey, (map_height, map_width), (8, 8)) noise = jnp.where(perlin_noise < -0.5, 1, 0) # mirror along diagonal noise = noise | noise.T @@ -297,9 +271,9 @@ def gen_map(key: chex.PRNGKey, params: EnvParams) -> chex.Array: ### Generate relic nodes ### key, subkey = jax.random.split(key) - noise = generate_perlin_noise_2d(subkey, (params.map_height, params.map_width), (4, 4)) + noise = generate_perlin_noise_2d(subkey, (map_height, map_width), (4, 4)) # Find the positions of the highest noise values - flat_indices = 
jnp.argsort(noise.ravel())[-params.max_relic_nodes // 2:] # Get indices of two highest values + flat_indices = jnp.argsort(noise.ravel())[-max_relic_nodes // 2:] # Get indices of two highest values highest_positions = jnp.column_stack(jnp.unravel_index(flat_indices, noise.shape)) # relic nodes have a fixed density of 25% nearby tiles can yield points @@ -307,9 +281,9 @@ def gen_map(key: chex.PRNGKey, params: EnvParams) -> chex.Array: jax.random.randint( key, shape=( - params.max_relic_nodes, - params.relic_config_size, - params.relic_config_size, + max_relic_nodes, + relic_config_size, + relic_config_size, ), minval=0, maxval=10, @@ -320,30 +294,30 @@ def gen_map(key: chex.PRNGKey, params: EnvParams) -> chex.Array: highest_positions = highest_positions.astype(jnp.int16) relic_nodes_mask = relic_nodes_mask.at[0].set(True) relic_nodes_mask = relic_nodes_mask.at[1].set(True) - mirrored_positions = jnp.stack([params.map_width - highest_positions[:, 1] - 1, params.map_height - highest_positions[:, 0] - 1], dtype=jnp.int16, axis=-1) + mirrored_positions = jnp.stack([map_width - highest_positions[:, 1] - 1, map_height - highest_positions[:, 0] - 1], dtype=jnp.int16, axis=-1) relic_nodes = jnp.concat([highest_positions, mirrored_positions], axis=0) key, subkey = jax.random.split(key) - relic_nodes_mask_half = jax.random.randint(key, (params.max_relic_nodes // 2, ), minval=0, maxval=2).astype(jnp.bool) + relic_nodes_mask_half = jax.random.randint(key, (max_relic_nodes // 2, ), minval=0, maxval=2).astype(jnp.bool) relic_nodes_mask_half = relic_nodes_mask_half.at[0].set(True) - relic_nodes_mask = relic_nodes_mask.at[:params.max_relic_nodes // 2].set(relic_nodes_mask_half) - relic_nodes_mask = relic_nodes_mask.at[params.max_relic_nodes // 2:].set(relic_nodes_mask_half) + relic_nodes_mask = relic_nodes_mask.at[:max_relic_nodes // 2].set(relic_nodes_mask_half) + relic_nodes_mask = relic_nodes_mask.at[max_relic_nodes // 2:].set(relic_nodes_mask_half) # import ipdb;ipdb.set_trace() - relic_node_configs = relic_node_configs.at[params.max_relic_nodes // 2:].set(relic_node_configs[:params.max_relic_nodes // 2].transpose(0, 2, 1)[:, ::-1, ::-1]) + relic_node_configs = relic_node_configs.at[max_relic_nodes // 2:].set(relic_node_configs[:max_relic_nodes // 2].transpose(0, 2, 1)[:, ::-1, ::-1]) ### Generate energy nodes ### key, subkey = jax.random.split(key) - noise = generate_perlin_noise_2d(subkey, (params.map_height, params.map_width), (4, 4)) + noise = generate_perlin_noise_2d(subkey, (map_height, map_width), (4, 4)) # Find the positions of the highest noise values - flat_indices = jnp.argsort(noise.ravel())[-params.max_energy_nodes // 2:] # Get indices of highest values + flat_indices = jnp.argsort(noise.ravel())[-max_energy_nodes // 2:] # Get indices of highest values highest_positions = jnp.column_stack(jnp.unravel_index(flat_indices, noise.shape)) - mirrored_positions = jnp.stack([params.map_width - highest_positions[:, 1] - 1, params.map_height - highest_positions[:, 0] - 1], dtype=jnp.int16, axis=-1) + mirrored_positions = jnp.stack([map_width - highest_positions[:, 1] - 1, map_height - highest_positions[:, 0] - 1], dtype=jnp.int16, axis=-1) energy_nodes = jnp.concat([highest_positions, mirrored_positions], axis=0) key, subkey = jax.random.split(key) - energy_nodes_mask_half = jax.random.randint(key, (params.max_energy_nodes // 2, ), minval=0, maxval=2).astype(jnp.bool) + energy_nodes_mask_half = jax.random.randint(key, (max_energy_nodes // 2, ), minval=0, maxval=2).astype(jnp.bool) 
energy_nodes_mask_half = energy_nodes_mask_half.at[0].set(True) - energy_nodes_mask = energy_nodes_mask.at[:params.max_energy_nodes // 2].set(energy_nodes_mask_half) - energy_nodes_mask = energy_nodes_mask.at[params.max_energy_nodes // 2:].set(energy_nodes_mask_half) + energy_nodes_mask = energy_nodes_mask.at[:max_energy_nodes // 2].set(energy_nodes_mask_half) + energy_nodes_mask = energy_nodes_mask.at[max_energy_nodes // 2:].set(energy_nodes_mask_half) # TODO (stao): provide more randomization options for energy node functions. energy_node_fns = jnp.array( diff --git a/src/luxai_s3/utils.py b/src/luxai_s3/utils.py index 0d19cdb..9e56360 100644 --- a/src/luxai_s3/utils.py +++ b/src/luxai_s3/utils.py @@ -1,4 +1,6 @@ import numpy as np + + def to_numpy(x): if isinstance(x, dict): return {k: to_numpy(v) for k, v in x.items()} @@ -7,4 +9,4 @@ def to_numpy(x): elif isinstance(x, np.ndarray): return x else: - return np.array(x) \ No newline at end of file + return np.array(x) diff --git a/src/luxai_s3/wrappers.py b/src/luxai_s3/wrappers.py index eb30b97..56ee245 100644 --- a/src/luxai_s3/wrappers.py +++ b/src/luxai_s3/wrappers.py @@ -15,13 +15,14 @@ from luxai_s3.state import serialize_env_actions, serialize_env_states from luxai_s3.utils import to_numpy + class LuxAIS3GymEnv(gym.Env): def __init__(self, numpy_output: bool = False): self.numpy_output = numpy_output self.rng_key = jax.random.key(0) self.jax_env = LuxAIS3Env(auto_reset=False) self.env_params: EnvParams = EnvParams() - + # auto run compiling steps here: # print("Running compilation steps") key = jax.random.key(0) @@ -43,16 +44,19 @@ def __init__(self, numpy_output: bool = False): low[:, 1:] = -self.env_params.unit_sap_range high = np.ones((self.env_params.max_units, 3)) * 6 high[:, 1:] = self.env_params.unit_sap_range - self.action_space = gym.spaces.Dict(dict( - player_0=gym.spaces.Box(low=low, high=high, dtype=np.int16), - player_1=gym.spaces.Box(low=low, high=high, dtype=np.int16) - )) - + self.action_space = gym.spaces.Dict( + dict( + player_0=gym.spaces.Box(low=low, high=high, dtype=np.int16), + player_1=gym.spaces.Box(low=low, high=high, dtype=np.int16), + ) + ) + def render(self): self.jax_env.render(self.state, self.env_params) - - - def reset(self, *, seed: int | None = None, options: dict[str, Any] | None = None) -> tuple[Any, dict[str, Any]]: + + def reset( + self, *, seed: int | None = None, options: dict[str, Any] | None = None + ) -> tuple[Any, dict[str, Any]]: if seed is not None: self.rng_key = jax.random.key(seed) self.rng_key, reset_key = jax.random.split(self.rng_key) @@ -61,26 +65,45 @@ def reset(self, *, seed: int | None = None, options: dict[str, Any] | None = Non randomized_game_params = dict() for k, v in env_params_ranges.items(): self.rng_key, subkey = jax.random.split(self.rng_key) - randomized_game_params[k] = jax.random.choice(subkey, jax.numpy.array(v)).item() + randomized_game_params[k] = jax.random.choice( + subkey, jax.numpy.array(v) + ).item() params = EnvParams(**randomized_game_params) if options is not None and "params" in options: params = options["params"] - + self.env_params = params obs, self.state = self.jax_env.reset(reset_key, params=params) if self.numpy_output: obs = to_numpy(flax.serialization.to_state_dict(obs)) - + # only keep the following game parameters available to the agent params_dict = dataclasses.asdict(params) params_dict_kept = dict() - for k in ["max_units", "match_count_per_episode", "max_steps_in_match", "map_height", "map_width", "num_teams", "unit_move_cost", 
"unit_sap_cost", "unit_sap_range", "unit_sensor_range"]: + for k in [ + "max_units", + "match_count_per_episode", + "max_steps_in_match", + "map_height", + "map_width", + "num_teams", + "unit_move_cost", + "unit_sap_cost", + "unit_sap_range", + "unit_sensor_range", + ]: params_dict_kept[k] = params_dict[k] - return obs, dict(params=params_dict_kept, full_params=params_dict, state=self.state) - - def step(self, action: Any) -> tuple[Any, SupportsFloat, bool, bool, dict[str, Any]]: + return obs, dict( + params=params_dict_kept, full_params=params_dict, state=self.state + ) + + def step( + self, action: Any + ) -> tuple[Any, SupportsFloat, bool, bool, dict[str, Any]]: self.rng_key, step_key = jax.random.split(self.rng_key) - obs, self.state, reward, terminated, truncated, info = self.jax_env.step(step_key, self.state, action, self.env_params) + obs, self.state, reward, terminated, truncated, info = self.jax_env.step( + step_key, self.state, action, self.env_params + ) if self.numpy_output: obs = to_numpy(flax.serialization.to_state_dict(obs)) reward = to_numpy(reward) @@ -89,10 +112,18 @@ def step(self, action: Any) -> tuple[Any, SupportsFloat, bool, bool, dict[str, A # info = to_numpy(flax.serialization.to_state_dict(info)) return obs, reward, terminated, truncated, info + # TODO: vectorized gym wrapper + class RecordEpisode(gym.Wrapper): - def __init__(self, env: LuxAIS3GymEnv, save_dir: str = None, save_on_close: bool = True, save_on_reset: bool = True): + def __init__( + self, + env: LuxAIS3GymEnv, + save_dir: str = None, + save_on_close: bool = True, + save_on_reset: bool = True, + ): super().__init__(env) self.episode = dict(states=[], actions=[], metadata=dict()) self.episode_id = 0 @@ -102,24 +133,30 @@ def __init__(self, env: LuxAIS3GymEnv, save_dir: str = None, save_on_close: bool self.episode_steps = 0 if save_dir is not None: from pathlib import Path + Path(save_dir).mkdir(parents=True, exist_ok=True) - def reset(self, *, seed: int | None = None, options: dict[str, Any] | None = None) -> tuple[Any, dict[str, Any]]: + def reset( + self, *, seed: int | None = None, options: dict[str, Any] | None = None + ) -> tuple[Any, dict[str, Any]]: if self.save_on_reset and self.episode_steps > 0: self._save_episode_and_reset() obs, info = self.env.reset(seed=seed, options=options) - + self.episode["metadata"]["seed"] = seed self.episode["params"] = flax.serialization.to_state_dict(info["full_params"]) self.episode["states"].append(info["state"]) return obs, info - def step(self, action: Any) -> tuple[Any, SupportsFloat, bool, bool, dict[str, Any]]: + + def step( + self, action: Any + ) -> tuple[Any, SupportsFloat, bool, bool, dict[str, Any]]: obs, reward, terminated, truncated, info = self.env.step(action) self.episode_steps += 1 self.episode["states"].append(info["final_state"]) self.episode["actions"].append(action) return obs, reward, terminated, truncated, info - + def serialize_episode_data(self, episode=None): if episode is None: episode = self.episode @@ -130,20 +167,21 @@ def serialize_episode_data(self, episode=None): ret["metadata"] = episode["metadata"] ret["params"] = episode["params"] return ret - + def save_episode(self, save_path: str): episode = self.serialize_episode_data() with open(save_path, "w") as f: json.dump(episode, f) self.episode = dict(states=[], actions=[], metadata=dict()) - + def _save_episode_and_reset(self): """saves to generated path based on self.save_dir and episoe id and updates relevant counters""" - self.save_episode(os.path.join(self.save_dir, 
f"episode_{self.episode_id}.json")) + self.save_episode( + os.path.join(self.save_dir, f"episode_{self.episode_id}.json") + ) self.episode_id += 1 self.episode_steps = 0 - + def close(self): if self.save_on_close and self.episode_steps > 0: self._save_episode_and_reset() - \ No newline at end of file diff --git a/src/tests/test.py b/src/tests/test.py index d4a7762..2e18126 100644 --- a/src/tests/test.py +++ b/src/tests/test.py @@ -11,6 +11,7 @@ import jax.numpy as jnp from luxai_s3.env import LuxAIS3Env + # from luxai_s3.wrappers import RecordEpisode # Create the environment @@ -33,8 +34,6 @@ subkey, state, action, params=env_params ) - - states = [] actions = [] key = jax.random.key(0) @@ -58,20 +57,17 @@ save_start_time = time.time() - states = serialize_env_states(states) - episode=dict(observations=states, actions=serialize_env_actions(actions)) + episode = dict(observations=states, actions=serialize_env_actions(actions)) episode["params"] = flax.serialization.to_state_dict(env_params) - episode["metadata"] = dict( - seed=0 - ) - + episode["metadata"] = dict(seed=0) + # jax.random.PRNGKey(episode["seed"]) # obs, state = env.reset(jax.random.wrap_key_data(episode["seed"]), params=env_params) with open("../lux-eye/src/pages/home/episode.json", "w") as f: json.dump(episode, f) save_end_time = time.time() - + save_duration = save_end_time - save_start_time print(f"Time taken to save episode: {save_duration:.4f} seconds") - # import ipdb; ipdb.set_trace() \ No newline at end of file + # import ipdb; ipdb.set_trace() diff --git a/src/tests/test_gpu.py b/src/tests/test_gpu.py index f8ab20a..d1a2f25 100644 --- a/src/tests/test_gpu.py +++ b/src/tests/test_gpu.py @@ -6,15 +6,20 @@ from luxai_s3.params import env_params_ranges from luxai_s3.state import gen_map from luxai_s3.wrappers import LuxAIS3GymEnv, RecordEpisode + if __name__ == "__main__": import numpy as np + np.random.seed(2) - - jax_env = LuxAIS3Env(auto_reset=True) + + # the first env params is not batched and is used to initialize any static / unchaging values + # like map size, max units etc. 
+ jax_env = LuxAIS3Env(auto_reset=True, fixed_env_params=EnvParams()) num_envs = 10 seed = 0 rng_key = jax.random.key(seed) reset_fn = jax.vmap(jax_env.reset_env) + # sample random params initially def sample_params(rng_key): randomized_game_params = dict() @@ -23,25 +28,36 @@ def sample_params(rng_key): randomized_game_params[k] = jax.random.choice(subkey, jax.numpy.array(v)) params = EnvParams(**randomized_game_params) return params - + rng_key, subkey = jax.random.split(rng_key) env_params = jax.vmap(sample_params)(jax.random.split(subkey, num_envs)) - # reset_fn(jax.random.split(subkey, num_envs), env_params) - jax.vmap(gen_map)(jax.random.split(subkey, num_envs), env_params) + # + # env_params = EnvParams() + # for k in env_params_ranges: + # env_params = env_params.replace(**{k: getattr(sampled_env_params, k)}) + # env_params = [EnvParams() for _ in range(num_envs)] + res = reset_fn(jax.random.split(subkey, num_envs), env_params) + # def gen_map(key, params): + # # jax.debug.breakpoint() + # # import ipdb; ipdb.set_trace() + # jax.numpy.zeros((params.map_height, 24)) + # return params.map_height+params.map_width + # gen_map = jax.jit(gen_map, static_argnums=(0,2,3,4,5,6, 7)) + # import ipdb; ipdb.set_trace() + # res = jax.vmap(gen_map, in_axes=(0, 0, None, None, None, None, None, None))(jax.random.split(subkey, num_envs), env_params, fixed_env_params.map_type, fixed_env_params.map_height, fixed_env_params.map_width, fixed_env_params.max_energy_nodes, fixed_env_params.max_relic_nodes, fixed_env_params.relic_config_size) - # env = LuxAIS3GymEnv() # env = RecordEpisode(env, save_dir="episodes") # obs, info = env.reset(seed=1) - - # print("Benchmarking time") - # stime = time.time() - # N = 100 - # # N = env.params.max_steps_in_match * env.params.match_count_per_episode - # for _ in range(N): - # env.step(env.action_space.sample()) - # etime = time.time() - # print(f"FPS: {N / (etime - stime)}") - - # env.close() \ No newline at end of file + + print("Benchmarking time") + stime = time.time() + N = 100 + N = fixed_env_params.max_steps_in_match * fixed_env_params.match_count_per_episode + for _ in range(N): + env.step(env.action_space.sample()) + etime = time.time() + print(f"FPS: {N / (etime - stime)}") + + # env.close() diff --git a/src/tests/test_gym.py b/src/tests/test_gym.py index 2db5810..5622a1f 100644 --- a/src/tests/test_gym.py +++ b/src/tests/test_gym.py @@ -4,14 +4,16 @@ from luxai_s3.params import EnvParams from luxai_s3.wrappers import LuxAIS3GymEnv, RecordEpisode + if __name__ == "__main__": import numpy as np + np.random.seed(2) env = LuxAIS3GymEnv() env = RecordEpisode(env, save_dir="episodes") env_params = EnvParams(map_type=0, max_steps_in_match=100) obs, info = env.reset(seed=1, options=dict(params=env_params)) - + print("Benchmarking time") stime = time.time() N = env_params.max_steps_in_match * env_params.match_count_per_episode @@ -19,5 +21,5 @@ env.step(env.action_space.sample()) etime = time.time() print(f"FPS: {N / (etime - stime)}") - - env.close() \ No newline at end of file + + env.close() From 971863cc4176fa698b0253f64a4ac886344eac7b Mon Sep 17 00:00:00 2001 From: StoneT2000 Date: Fri, 25 Oct 2024 10:29:40 -0700 Subject: [PATCH 04/11] fixes --- src/luxai_s3/env.py | 66 ++++++++++++++++++++++--------------------- src/luxai_s3/state.py | 7 ++--- src/tests/test_gpu.py | 51 +++++++++++++++------------------ 3 files changed, 60 insertions(+), 64 deletions(-) diff --git a/src/luxai_s3/env.py b/src/luxai_s3/env.py index 3a2072b..f8b7d8f 100644 --- 
a/src/luxai_s3/env.py +++ b/src/luxai_s3/env.py @@ -41,7 +41,7 @@ def default_params(self) -> EnvParams: def compute_unit_counts_map(self, state: EnvState, params: EnvParams): # map of total units per team on each tile, shape (num_teams, map_width, map_height) unit_counts_map = jnp.zeros( - (params.num_teams, params.map_width, params.map_height), dtype=jnp.int32 + (self.fixed_env_params.num_teams, self.fixed_env_params.map_width, self.fixed_env_params.map_height), dtype=jnp.int32 ) def update_unit_counts_map(unit_position, unit_mask, unit_counts_map): @@ -50,7 +50,7 @@ def update_unit_counts_map(unit_position, unit_mask, unit_counts_map): ].add(unit_mask) return unit_counts_map - for t in range(params.num_teams): + for t in range(self.fixed_env_params.num_teams): unit_counts_map = unit_counts_map.at[t].add( jnp.sum( jax.vmap(update_unit_counts_map, in_axes=(0, 0, None), out_axes=0)( @@ -64,7 +64,7 @@ def update_unit_counts_map(unit_position, unit_mask, unit_counts_map): def compute_energy_features(self, state: EnvState, params: EnvParams): # first compute a array of shape (map_height, map_width, num_energy_nodes) with values equal to the distance of the tile to the energy node mm = jnp.meshgrid(jnp.arange(self.fixed_env_params.map_width), jnp.arange(self.fixed_env_params.map_height)) - mm = jnp.stack([mm[0], mm[1]]).T # mm[x, y] gives [x, y] + mm = jnp.stack([mm[0], mm[1]]).T.astype(jnp.int16) # mm[x, y] gives [x, y] distances_to_nodes = jax.vmap(lambda pos: jnp.linalg.norm(mm - pos, axis=-1))( state.energy_nodes ) @@ -136,7 +136,7 @@ def update_vision_power_map(unit_pos, vision_power_map): ) update = jnp.zeros_like(existing_vision_power) for i in range(max_sensor_range + 1): - val = jnp.where(i > max_sensor_range - params.unit_sensor_range, i + 1 - params.unit_sensor_range, 0) + val = jnp.where(i > max_sensor_range - params.unit_sensor_range, i + 1 - params.unit_sensor_range, 0).astype(jnp.int16) update = update.at[ i : max_sensor_range * 2 + 1 - i, i : max_sensor_range * 2 + 1 - i, @@ -290,9 +290,9 @@ def sap_unit( units_mask, ): # TODO (stao): clean up this code. It is probably slower than it needs be and could be vmapped perhaps. - for t in range(params.num_teams): + for t in range(self.fixed_env_params.num_teams): other_team_ids = jnp.array( - [t2 for t2 in range(params.num_teams) if t2 != t] + [t2 for t2 in range(self.fixed_env_params.num_teams) if t2 != t] ) team_sap_action_deltas = sap_action_deltas[t] # (max_units, 2) team_sap_action_mask = sap_action_mask[t] @@ -315,8 +315,8 @@ def sap_unit( team_unit_sapped = ( team_unit_sapped & (team_sapped_positions >= 0).all(-1) - & (team_sapped_positions[:, 0] < params.map_width) - & (team_sapped_positions[:, 1] < params.map_height) + & (team_sapped_positions[:, 0] < self.fixed_env_params.map_width) + & (team_sapped_positions[:, 1] < self.fixed_env_params.map_height) ) # the number of times other units are sapped other_units_sapped_count = jnp.sum( @@ -412,16 +412,16 @@ def sap_unit( # compute energy void fields for all teams and the energy + unit counts unit_aggregate_energy_void_map = jnp.zeros( - shape=(params.num_teams, params.map_width, params.map_height), + shape=(self.fixed_env_params.num_teams, self.fixed_env_params.map_width, self.fixed_env_params.map_height), dtype=jnp.float32, ) unit_counts_map = self.compute_unit_counts_map(state, params) # TODO (stao): this doesn't need to be a float? 
unit_aggregate_energy_map = jnp.zeros( - shape=(params.num_teams, params.map_width, params.map_height), + shape=(self.fixed_env_params.num_teams, self.fixed_env_params.map_width, self.fixed_env_params.map_height), dtype=jnp.float32, ) - for t in range(params.num_teams): + for t in range(self.fixed_env_params.num_teams): def scan_body(carry, x): agg_energy_void_map, agg_energy_map = carry @@ -433,9 +433,9 @@ def scan_body(carry, x): new_pos = unit_position + jnp.array(deltas) in_map = ( (new_pos[0] >= 0) - & (new_pos[0] < params.map_width) + & (new_pos[0] < self.fixed_env_params.map_width) & (new_pos[1] >= 0) - & (new_pos[1] < params.map_height) + & (new_pos[1] < self.fixed_env_params.map_height) ) agg_energy_void_map = agg_energy_void_map.at[ new_pos[0], new_pos[1] @@ -455,9 +455,9 @@ def scan_body(carry, x): ) # resolve collisions and keep only the surviving units - for t in range(params.num_teams): + for t in range(self.fixed_env_params.num_teams): other_team_ids = jnp.array( - [t2 for t2 in range(params.num_teams) if t2 != t] + [t2 for t2 in range(self.fixed_env_params.num_teams) if t2 != t] ) # get the energy map for the current team opposing_unit_counts_map = unit_counts_map[other_team_ids].sum( @@ -485,9 +485,9 @@ def scan_body(carry, x): ) ) # apply energy void fields - for t in range(params.num_teams): + for t in range(self.fixed_env_params.num_teams): other_team_ids = jnp.array( - [t2 for t2 in range(params.num_teams) if t2 != t] + [t2 for t2 in range(self.fixed_env_params.num_teams) if t2 != t] ) oppposition_energy_void_map = unit_aggregate_energy_void_map[ other_team_ids @@ -650,7 +650,7 @@ def spawn_team_units(state: EnvState): energy_node_deltas = jnp.round( jax.random.uniform( key=key, - shape=(params.max_energy_nodes // 2, 2), + shape=(self.fixed_env_params.max_energy_nodes // 2, 2), minval=-params.energy_node_drift_magnitude, maxval=params.energy_node_drift_magnitude, ) @@ -666,7 +666,7 @@ def spawn_team_units(state: EnvState): new_energy_nodes = jnp.clip( state.energy_nodes + energy_node_deltas, min=jnp.array([0, 0]), - max=jnp.array([params.map_width, params.map_height]), + max=jnp.array([self.fixed_env_params.map_width, self.fixed_env_params.map_height]), ) new_energy_nodes = jnp.where( state.steps * params.energy_node_drift_speed % 1 == 0, @@ -732,7 +732,7 @@ def team_relic_score(unit_counts_map): >= (params.max_steps_in_match + 1) * params.match_count_per_episode ) reward = dict() - for k in range(params.num_teams): + for k in range(self.fixed_env_params.num_teams): reward[f"player_{k}"] = state.team_wins[k] terminated = self.is_terminal(state, params) return ( @@ -766,7 +766,7 @@ def reset_env( return self.get_obs(state, params=params, key=key), state - @functools.partial(jax.jit, static_argnums=(0, 4)) + @functools.partial(jax.jit, static_argnums=(0,)) def step( self, key: chex.PRNGKey, @@ -787,10 +787,11 @@ def step( if self.auto_reset: done = terminated | truncated obs_re, state_re = self.reset_env(key_reset, params) - state = jax.tree_map( - lambda x, y: jax.lax.select(done, x, y), state_re, state_st - ) - obs = jax.lax.select(done, obs_re, obs_st) + # state = jax.tree_map( + # lambda x, y: jax.lax.select(done, x, y), state_re, state_st + # ) + obs = obs_st + # obs = jax.lax.select(done, obs_re, obs_st) else: obs = obs_st state = state_st @@ -800,7 +801,7 @@ def step( # all agents terminate/truncate at same time terminated_dict = dict() truncated_dict = dict() - for k in range(params.num_teams): + for k in range(self.fixed_env_params.num_teams): 
terminated_dict[f"player_{k}"] = terminated truncated_dict[f"player_{k}"] = truncated info[f"player_{k}"] = dict() @@ -814,6 +815,7 @@ def reset( # Use default env parameters if no others specified if params is None: params = self.default_params + obs, state = self.reset_env(key, params) return obs, state @@ -879,7 +881,7 @@ def update_relic_nodes_mask(relic_nodes_mask, relic_nodes, sensor_mask): relic_nodes_mask=new_relic_nodes_mask, ) obs[f"player_{t}"] = team_obs - return obs + return 3 @functools.partial(jax.jit, static_argnums=(0, 2)) def is_terminal(self, state: EnvState, params: EnvParams) -> jnp.ndarray: @@ -895,12 +897,12 @@ def name(self) -> str: def render(self, state: EnvState, params: EnvParams): self.renderer.render(state, params) - def action_space(self, params: EnvParams): + def action_space(self, params: Optional[EnvParams] = None): """Action space of the environment.""" - low = np.zeros((params.max_units, 3)) - low[:, 1:] = -params.unit_sap_range - high = np.ones((params.max_units, 3)) * 6 - high[:, 1:] = params.unit_sap_range + low = np.zeros((self.fixed_env_params.max_units, 3)) + low[:, 1:] = -env_params_ranges["unit_sap_range"][-1] + high = np.ones((self.fixed_env_params.max_units, 3)) * 6 + high[:, 1:] = env_params_ranges["unit_sap_range"][-1] return spaces.Dict( dict(player_0=MultiDiscrete(low, high), player_1=MultiDiscrete(low, high)) ) diff --git a/src/luxai_s3/state.py b/src/luxai_s3/state.py index 625adea..f5bf843 100644 --- a/src/luxai_s3/state.py +++ b/src/luxai_s3/state.py @@ -198,7 +198,7 @@ def update_relic_node(relic_nodes_map_weights, relic_data): ) relic_nodes_map_weights = jnp.where( valid_pos & mask, - relic_nodes_map_weights.at[x, y].add(relic_node_config[dx, dy]), + relic_nodes_map_weights.at[x, y].add(relic_node_config[dx, dy].astype(jnp.int16)), relic_nodes_map_weights, ) return relic_nodes_map_weights, None @@ -287,8 +287,7 @@ def gen_map(key: chex.PRNGKey, params: EnvParams, map_type: int, map_height: int ), minval=0, maxval=10, - dtype=jnp.int16, - ) + ).astype(jnp.float32) >= 7.5 ) highest_positions = highest_positions.astype(jnp.int16) @@ -310,7 +309,7 @@ def gen_map(key: chex.PRNGKey, params: EnvParams, map_type: int, map_height: int noise = generate_perlin_noise_2d(subkey, (map_height, map_width), (4, 4)) # Find the positions of the highest noise values flat_indices = jnp.argsort(noise.ravel())[-max_energy_nodes // 2:] # Get indices of highest values - highest_positions = jnp.column_stack(jnp.unravel_index(flat_indices, noise.shape)) + highest_positions = jnp.column_stack(jnp.unravel_index(flat_indices, noise.shape)).astype(jnp.int16) mirrored_positions = jnp.stack([map_width - highest_positions[:, 1] - 1, map_height - highest_positions[:, 0] - 1], dtype=jnp.int16, axis=-1) energy_nodes = jnp.concat([highest_positions, mirrored_positions], axis=0) key, subkey = jax.random.split(key) diff --git a/src/tests/test_gpu.py b/src/tests/test_gpu.py index d1a2f25..7dfd81a 100644 --- a/src/tests/test_gpu.py +++ b/src/tests/test_gpu.py @@ -9,16 +9,18 @@ if __name__ == "__main__": import numpy as np + # jax.config.update('jax_numpy_dtype_promotion', 'strict') np.random.seed(2) # the first env params is not batched and is used to initialize any static / unchaging values # like map size, max units etc. 
- jax_env = LuxAIS3Env(auto_reset=True, fixed_env_params=EnvParams()) - num_envs = 10 + env = LuxAIS3Env(auto_reset=True, fixed_env_params=EnvParams()) + num_envs = 100 seed = 0 rng_key = jax.random.key(seed) - reset_fn = jax.vmap(jax_env.reset_env) + reset_fn = jax.vmap(env.reset) + step_fn = jax.vmap(env.step) # sample random params initially def sample_params(rng_key): @@ -31,33 +33,26 @@ def sample_params(rng_key): rng_key, subkey = jax.random.split(rng_key) env_params = jax.vmap(sample_params)(jax.random.split(subkey, num_envs)) - # - # env_params = EnvParams() - # for k in env_params_ranges: - # env_params = env_params.replace(**{k: getattr(sampled_env_params, k)}) - # env_params = [EnvParams() for _ in range(num_envs)] - res = reset_fn(jax.random.split(subkey, num_envs), env_params) - - # def gen_map(key, params): - # # jax.debug.breakpoint() - # # import ipdb; ipdb.set_trace() - # jax.numpy.zeros((params.map_height, 24)) - # return params.map_height+params.map_width - # gen_map = jax.jit(gen_map, static_argnums=(0,2,3,4,5,6, 7)) - # import ipdb; ipdb.set_trace() - # res = jax.vmap(gen_map, in_axes=(0, 0, None, None, None, None, None, None))(jax.random.split(subkey, num_envs), env_params, fixed_env_params.map_type, fixed_env_params.map_height, fixed_env_params.map_width, fixed_env_params.max_energy_nodes, fixed_env_params.max_relic_nodes, fixed_env_params.relic_config_size) - - # env = LuxAIS3GymEnv() - # env = RecordEpisode(env, save_dir="episodes") - # obs, info = env.reset(seed=1) + action_space = env.action_space() # note that this can generate sap actions beyond range atm + sample_action = jax.vmap(action_space.sample) + obs, state = reset_fn(jax.random.split(subkey, num_envs), env_params) + obs, state, reward, terminated_dict, truncated_dict, info = step_fn( + jax.random.split(subkey, num_envs), + state, + sample_action(jax.random.split(subkey, num_envs)), + env_params + ) print("Benchmarking time") stime = time.time() - N = 100 - N = fixed_env_params.max_steps_in_match * fixed_env_params.match_count_per_episode + N = env.fixed_env_params.max_steps_in_match * env.fixed_env_params.match_count_per_episode + obs, state = reset_fn(jax.random.split(subkey, num_envs), env_params) for _ in range(N): - env.step(env.action_space.sample()) + obs, state, reward, terminated_dict, truncated_dict, info = step_fn( + jax.random.split(subkey, num_envs), + state, + sample_action(jax.random.split(subkey, num_envs)), + env_params + ) etime = time.time() - print(f"FPS: {N / (etime - stime)}") - - # env.close() + print(f"FPS: {N * num_envs / (etime - stime):0.3f}. {N / (etime - stime):0.3f} parallel steps/s") From cb7e5a7c2c1fb42ea6e85b2224ccf6c1bd095632 Mon Sep 17 00:00:00 2001 From: Stone Tao Date: Sat, 26 Oct 2024 07:00:26 -0700 Subject: [PATCH 05/11] Update src/luxai_s3/env.py Co-authored-by: CJBoey <19331431+CJBoey@users.noreply.github.com> --- src/luxai_s3/env.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/luxai_s3/env.py b/src/luxai_s3/env.py index f8b7d8f..28c79c3 100644 --- a/src/luxai_s3/env.py +++ b/src/luxai_s3/env.py @@ -109,7 +109,8 @@ def compute_sensor_masks(self, state, params: EnvParams): Now any time the vision power map has value > 0, the team can sense the tile. 
This forms the sensor mask """ - vision_power_map_padding = self.fixed_env_params.unit_sensor_range + max_sensor_range = env_params_ranges["unit_sensor_range"][-1] + vision_power_map_padding = max_sensor_range vision_power_map = jnp.zeros( shape=( self.fixed_env_params.num_teams, From e0d81f9784935da0b2dab86d62d16a93f44ba115 Mon Sep 17 00:00:00 2001 From: Stone Tao Date: Sat, 26 Oct 2024 07:00:38 -0700 Subject: [PATCH 06/11] Update src/luxai_s3/env.py Co-authored-by: CJBoey <19331431+CJBoey@users.noreply.github.com> --- src/luxai_s3/env.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/luxai_s3/env.py b/src/luxai_s3/env.py index 28c79c3..18db910 100644 --- a/src/luxai_s3/env.py +++ b/src/luxai_s3/env.py @@ -121,7 +121,6 @@ def compute_sensor_masks(self, state, params: EnvParams): ) # Update sensor mask based on the sensor range - max_sensor_range = env_params_ranges["unit_sensor_range"][-1] def update_vision_power_map(unit_pos, vision_power_map): x, y = unit_pos existing_vision_power = jax.lax.dynamic_slice( From 21ca2deb2d7320529e1a36156a45e82bda40faf9 Mon Sep 17 00:00:00 2001 From: Stone Tao Date: Sat, 26 Oct 2024 07:00:55 -0700 Subject: [PATCH 07/11] Update src/luxai_s3/env.py Co-authored-by: CJBoey <19331431+CJBoey@users.noreply.github.com> --- src/luxai_s3/env.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/luxai_s3/env.py b/src/luxai_s3/env.py index 18db910..2e1ffae 100644 --- a/src/luxai_s3/env.py +++ b/src/luxai_s3/env.py @@ -136,7 +136,7 @@ def update_vision_power_map(unit_pos, vision_power_map): ) update = jnp.zeros_like(existing_vision_power) for i in range(max_sensor_range + 1): - val = jnp.where(i > max_sensor_range - params.unit_sensor_range, i + 1 - params.unit_sensor_range, 0).astype(jnp.int16) + val = jnp.where(i > max_sensor_range - params.unit_sensor_range - 1, i + 1 - params.unit_sensor_range, 0).astype(jnp.int16) update = update.at[ i : max_sensor_range * 2 + 1 - i, i : max_sensor_range * 2 + 1 - i, From b75b4e89bffef9d91b3c0a16532dfd4d0e4bdc8b Mon Sep 17 00:00:00 2001 From: StoneT2000 Date: Sat, 26 Oct 2024 07:41:07 -0700 Subject: [PATCH 08/11] fix type promotion errors --- src/luxai_runner/cli.py | 4 +-- src/luxai_s3/env.py | 66 ++++++++++++++++++++++------------------- src/luxai_s3/params.py | 1 + src/tests/test_gpu.py | 19 ++++++++---- 4 files changed, 52 insertions(+), 38 deletions(-) diff --git a/src/luxai_runner/cli.py b/src/luxai_runner/cli.py index f37b455..1cdf545 100644 --- a/src/luxai_runner/cli.py +++ b/src/luxai_runner/cli.py @@ -2,7 +2,7 @@ import json import sys from pathlib import Path -from typing import Dict, List +from typing import Annotated, Dict, List import numpy as np from luxai_runner.bot import Bot @@ -30,7 +30,7 @@ class Args: """Paths to player modules. If --tournament is passed as well, you can also pass a folder and we will look through all sub-folders for valid agents with main.py files (only works for python agents at the moment).""" len: Optional[int] = 1000 """Max episode length""" - output: Optional[str] = None + output: Annotated[Optional[str], tyro.conf.arg(aliases=["-o"])] = None """Where to output replays. 
Default is none and no replay is generated""" replay: ReplayConfig = field(default_factory=lambda: ReplayConfig()) diff --git a/src/luxai_s3/env.py b/src/luxai_s3/env.py index 2e1ffae..488407f 100644 --- a/src/luxai_s3/env.py +++ b/src/luxai_s3/env.py @@ -41,13 +41,13 @@ def default_params(self) -> EnvParams: def compute_unit_counts_map(self, state: EnvState, params: EnvParams): # map of total units per team on each tile, shape (num_teams, map_width, map_height) unit_counts_map = jnp.zeros( - (self.fixed_env_params.num_teams, self.fixed_env_params.map_width, self.fixed_env_params.map_height), dtype=jnp.int32 + (self.fixed_env_params.num_teams, self.fixed_env_params.map_width, self.fixed_env_params.map_height), dtype=jnp.int16 ) def update_unit_counts_map(unit_position, unit_mask, unit_counts_map): unit_counts_map = unit_counts_map.at[ unit_position[0], unit_position[1] - ].add(unit_mask) + ].add(unit_mask.astype(jnp.int16)) return unit_counts_map for t in range(self.fixed_env_params.num_teams): @@ -57,6 +57,7 @@ def update_unit_counts_map(unit_position, unit_mask, unit_counts_map): state.units.position[t], state.units_mask[t], unit_counts_map[t] ), axis=0, + dtype=jnp.int16 ) ) return unit_counts_map @@ -182,11 +183,10 @@ def body_fun(carry, i): vision_power_map_padding:-vision_power_map_padding, vision_power_map_padding:-vision_power_map_padding, ] - # handle nebula tiles vision_power_map = ( vision_power_map - - (state.map_features.tile_type == NEBULA_TILE) + - (state.map_features.tile_type == NEBULA_TILE).astype(jnp.int16) * params.nebula_tile_vision_reduction ) @@ -354,7 +354,7 @@ def sap_unit( [1, -1], [1, 0], [1, 1], - ] + ], dtype=jnp.int16 ) team_sapped_adjacent_positions = ( team_sapped_positions[:, None, :] + adjacent_offsets @@ -376,9 +376,9 @@ def sap_unit( & (other_units_adjacent_sapped_count[:, :, None] > 0), all_units.energy[other_team_ids] - jnp.array( - params.unit_sap_cost + params.unit_sap_cost.astype(jnp.float32) * params.unit_sap_dropoff_factor - * other_units_adjacent_sapped_count[:, :, None], + * other_units_adjacent_sapped_count[:, :, None].astype(jnp.float32), dtype=jnp.int16, ), all_units.energy[other_team_ids], @@ -413,13 +413,12 @@ def sap_unit( # compute energy void fields for all teams and the energy + unit counts unit_aggregate_energy_void_map = jnp.zeros( shape=(self.fixed_env_params.num_teams, self.fixed_env_params.map_width, self.fixed_env_params.map_height), - dtype=jnp.float32, + dtype=jnp.int16, ) unit_counts_map = self.compute_unit_counts_map(state, params) - # TODO (stao): this doesn't need to be a float? 
unit_aggregate_energy_map = jnp.zeros( shape=(self.fixed_env_params.num_teams, self.fixed_env_params.map_width, self.fixed_env_params.map_height), - dtype=jnp.float32, + dtype=jnp.int16, ) for t in range(self.fixed_env_params.num_teams): @@ -428,9 +427,9 @@ def scan_body(carry, x): unit_energy, unit_position, unit_mask = x agg_energy_map = agg_energy_map.at[ unit_position[0], unit_position[1] - ].add(unit_energy[0] * unit_mask) + ].add(unit_energy[0] * unit_mask.astype(jnp.int16)) for deltas in [(-1, 0), (1, 0), (0, -1), (0, 1)]: - new_pos = unit_position + jnp.array(deltas) + new_pos = unit_position + jnp.array(deltas, dtype=jnp.int16) in_map = ( (new_pos[0] >= 0) & (new_pos[0] < self.fixed_env_params.map_width) @@ -439,7 +438,7 @@ def scan_body(carry, x): ) agg_energy_void_map = agg_energy_void_map.at[ new_pos[0], new_pos[1] - ].add(unit_energy[0] * unit_mask * in_map) + ].add(unit_energy[0] * unit_mask.astype(jnp.int16) * in_map.astype(jnp.int16)) return (agg_energy_void_map, agg_energy_map), None agg_energy_void_map, agg_energy_map = jax.lax.scan( @@ -498,8 +497,8 @@ def scan_body(carry, x): team_unit_energy = state.units.energy[t] - jnp.floor( jax.vmap( lambda unit_position: params.unit_energy_void_factor - * oppposition_energy_void_map[unit_position[0], unit_position[1]] - / unit_counts_map[t][unit_position[0], unit_position[1]] + * oppposition_energy_void_map[unit_position[0], unit_position[1]].astype(jnp.float32) + / unit_counts_map[t][unit_position[0], unit_position[1]].astype(jnp.float32) )(state.units.position[t])[..., None] ).astype(jnp.int16) state = state.replace( @@ -515,7 +514,7 @@ def update_unit_energy(unit: UnitState, mask): x, y = unit.position energy_gain = ( state.map_features.energy[x, y] - - (state.map_features.tile_type[x, y] == NEBULA_TILE) + - (state.map_features.tile_type[x, y] == NEBULA_TILE).astype(jnp.int16) * params.nebula_tile_energy_reduction ) # if energy gain is less than 0 @@ -665,8 +664,8 @@ def spawn_team_units(state: EnvState): ) new_energy_nodes = jnp.clip( state.energy_nodes + energy_node_deltas, - min=jnp.array([0, 0]), - max=jnp.array([self.fixed_env_params.map_width, self.fixed_env_params.map_height]), + min=jnp.array([0, 0], dtype=jnp.int16), + max=jnp.array([self.fixed_env_params.map_width, self.fixed_env_params.map_height], dtype=jnp.int16), ) new_energy_nodes = jnp.where( state.steps * params.energy_node_drift_speed % 1 == 0, @@ -697,7 +696,7 @@ def team_relic_score(unit_counts_map): -1, ) winner_by_energy = jnp.sum( - state.units.energy[..., 0] * state.units_mask, axis=1 + state.units.energy[..., 0] * state.units_mask.astype(jnp.int16), axis=1 ) winner_by_energy = jnp.where( winner_by_energy.max() > winner_by_energy.min(), @@ -784,19 +783,24 @@ def step( ) info["final_state"] = state_st info["final_observation"] = obs_st + done = terminated | truncated + if self.auto_reset: - done = terminated | truncated obs_re, state_re = self.reset_env(key_reset, params) - # state = jax.tree_map( - # lambda x, y: jax.lax.select(done, x, y), state_re, state_st - # ) - obs = obs_st - # obs = jax.lax.select(done, obs_re, obs_st) + # Use lax.cond to efficiently choose between obs_re and obs_st + obs = jax.lax.cond( + done, + lambda: obs_re, + lambda: obs_st + ) + state = jax.lax.cond( + done, + lambda: state_re, + lambda: state_st + ) else: obs = obs_st state = state_st - # Auto-reset environment based on done - done = terminated | truncated # all agents terminate/truncate at same time terminated_dict = dict() @@ -881,11 +885,11 @@ def 
update_relic_nodes_mask(relic_nodes_mask, relic_nodes, sensor_mask): relic_nodes_mask=new_relic_nodes_mask, ) obs[f"player_{t}"] = team_obs - return 3 + return obs - @functools.partial(jax.jit, static_argnums=(0, 2)) + @functools.partial(jax.jit, static_argnums=(0, )) def is_terminal(self, state: EnvState, params: EnvParams) -> jnp.ndarray: - """Check whether state is terminal. This occurs when either team wins/loses outright.""" + """Check whether state is terminal. This never occurs. Game is only done when the time limit is reached.""" terminated = jnp.array(False) return terminated @@ -914,3 +918,5 @@ def observation_space(self, params: EnvParams): def state_space(self, params: EnvParams): """State space of the environment.""" return spaces.Discrete(10) + + diff --git a/src/luxai_s3/params.py b/src/luxai_s3/params.py index c079a26..f2e6abe 100644 --- a/src/luxai_s3/params.py +++ b/src/luxai_s3/params.py @@ -1,4 +1,5 @@ from flax import struct +import jax MAP_TYPES = ["dev0", "random"] diff --git a/src/tests/test_gpu.py b/src/tests/test_gpu.py index 7dfd81a..d8845fd 100644 --- a/src/tests/test_gpu.py +++ b/src/tests/test_gpu.py @@ -1,5 +1,6 @@ import time import jax +import jax.numpy as jnp import flax.serialization from luxai_s3.params import EnvParams from luxai_s3.env import LuxAIS3Env @@ -9,13 +10,16 @@ if __name__ == "__main__": import numpy as np - # jax.config.update('jax_numpy_dtype_promotion', 'strict') + jax.config.update('jax_numpy_dtype_promotion', 'strict') np.random.seed(2) # the first env params is not batched and is used to initialize any static / unchaging values # like map size, max units etc. - env = LuxAIS3Env(auto_reset=True, fixed_env_params=EnvParams()) + # note auto_reset=False for speed reasons. If True, the default jax code will attempt to reset each time and discard the reset if its not time to reset + # due to jax branching logic. It should be kept false and instead lax.scan followed by a reset after max episode steps should be used when possible since games + # can't end early. 
+ env = LuxAIS3Env(auto_reset=False, fixed_env_params=EnvParams()) num_envs = 100 seed = 0 rng_key = jax.random.key(seed) @@ -27,7 +31,10 @@ def sample_params(rng_key): randomized_game_params = dict() for k, v in env_params_ranges.items(): rng_key, subkey = jax.random.split(rng_key) - randomized_game_params[k] = jax.random.choice(subkey, jax.numpy.array(v)) + if isinstance(v[0], int): + randomized_game_params[k] = jax.random.choice(subkey, jax.numpy.array(v, dtype=jnp.int16)) + else: + randomized_game_params[k] = jax.random.choice(subkey, jax.numpy.array(v, dtype=jnp.float32)) params = EnvParams(**randomized_game_params) return params @@ -43,11 +50,11 @@ def sample_params(rng_key): env_params ) - print("Benchmarking time") + max_episode_steps = env.fixed_env_params.max_steps_in_match * env.fixed_env_params.match_count_per_episode + print("Benchmarking reset + for loop over jax.step time") stime = time.time() - N = env.fixed_env_params.max_steps_in_match * env.fixed_env_params.match_count_per_episode obs, state = reset_fn(jax.random.split(subkey, num_envs), env_params) - for _ in range(N): + for _ in range(max_episode_steps): obs, state, reward, terminated_dict, truncated_dict, info = step_fn( jax.random.split(subkey, num_envs), state, From d65e0878bbe9b319b0d39d8602733f46f5a920c3 Mon Sep 17 00:00:00 2001 From: StoneT2000 Date: Sat, 26 Oct 2024 08:22:19 -0700 Subject: [PATCH 09/11] benchmarking --- src/luxai_s3/profiler.py | 140 +++++++++++++++++++++++++++++++++++++ src/tests/benchmark_env.py | 135 +++++++++++++++++++++++++++++++++++ src/tests/test_gpu.py | 65 ----------------- 3 files changed, 275 insertions(+), 65 deletions(-) create mode 100644 src/luxai_s3/profiler.py create mode 100644 src/tests/benchmark_env.py delete mode 100644 src/tests/test_gpu.py diff --git a/src/luxai_s3/profiler.py b/src/luxai_s3/profiler.py new file mode 100644 index 0000000..c42f3ca --- /dev/null +++ b/src/luxai_s3/profiler.py @@ -0,0 +1,140 @@ +from collections import defaultdict +import os +import time +from contextlib import contextmanager +from typing import Literal +import numpy as np + +import psutil +import pynvml +import subprocess as sp +def flatten_dict_keys(d: dict, prefix=""): + """Flatten a dict by expanding its keys recursively.""" + out = dict() + for k, v in d.items(): + if isinstance(v, dict): + out.update(flatten_dict_keys(v, prefix + k + "/")) + else: + out[prefix + k] = v + return out +class Profiler: + """ + A simple class to help profile/benchmark simulator code + """ + + def __init__( + self, output_format: Literal["stdout", "json"], synchronize_torch: bool = True + ) -> None: + self.output_format = output_format + self.synchronize_torch = synchronize_torch + self.stats = defaultdict(list) + # Initialize NVML + pynvml.nvmlInit() + + # Get handle for the first GPU (index 0) + self.handle = pynvml.nvmlDeviceGetHandleByIndex(0) + + # Get the PID of the current process + self.current_pid = os.getpid() + + def log(self, msg): + """log a message to stdout""" + if self.output_format == "stdout": + print(msg) + + def update_csv(self, csv_path: str, data: dict): + """Update a csv file with the given data (a dict representing a unique identifier of the result row) + and stats. If the file does not exist, it will be created. The update will replace an existing row + if the given data matches the data in the row. 
If there are multiple matches, only the first match + will be replaced and the rest are deleted""" + import pandas as pd + import os + + if os.path.exists(csv_path): + df = pd.read_csv(csv_path) + else: + df = pd.DataFrame() + stats_flat = flatten_dict_keys(self.stats) + cond = None + + for k in stats_flat: + if k not in df: + df[k] = None + for k in data: + if k not in df: + df[k] = None + + mask = df[k].isna() if data[k] is None else df[k] == data[k] + if cond is None: + cond = mask + else: + cond = cond & mask + data_dict = {**data, **stats_flat} + if not cond.any(): + df = pd.concat([df, pd.DataFrame(data_dict, index=[len(df)])]) + else: + # replace the first instance + df.loc[df.loc[cond].index[0]] = data_dict + df.drop(df.loc[cond].index[1:], inplace=True) + # delete other instances + df.to_csv(csv_path, index=False) + + def profile(self, function, name: str, total_steps: int, num_envs: int, trials=1): + print(f"start recording {name} metrics") + process = psutil.Process(os.getpid()) + cpu_mem_use = process.memory_info().rss + gpu_mem_use = self.get_current_process_gpu_memory() + if gpu_mem_use is None: + gpu_mem_use = 0 + + for trial in range(trials): + stime = time.time() + function() + dt = time.time() - stime + # dt: delta time (s) + # fps: frames per second + # psps: parallel steps per second (number of env.step calls per second) + self.stats[name].append(dict( + dt=dt, + fps=total_steps * num_envs / dt, + psps=total_steps / dt, + total_steps=total_steps, + cpu_mem_use=cpu_mem_use, + gpu_mem_use=gpu_mem_use, + )) + # torch.cuda.synchronize() + + def log_stats(self, name: str): + stats = self.stats[name] + more_than_one_trial = len(stats) > 1 + if len(stats) == 0: + return + # average the stats + avg_stats = defaultdict(list) + for data in stats: + for k, v in data.items(): + avg_stats[k].append(v) + stats = {k: {"avg": np.mean(v), "std": np.std(v) if len(v) > 1 else None} for k, v in avg_stats.items()} + self.log( + f"{name} ({len(stats)} trials)" + ) + self.log( + f"{stats['fps']['avg']:0.3f} steps/s, {stats['psps']['avg']:0.3f} parallel steps/s, {stats['total_steps']['avg']} steps in {stats['dt']['avg']:0.3f}s" + ) + if more_than_one_trial: + self.log( + f"{stats['fps']['std']:0.3f} steps/s, {stats['psps']['std']:0.3f} parallel steps/s, {stats['total_steps']['std']} steps in {stats['dt']['std']:0.3f}s" + ) + self.log( + f"{' ' * 4}CPU mem: {avg_stats['cpu_mem_use'] / (1024**2):0.3f} MB, GPU mem: {avg_stats['gpu_mem_use'] / (1024**2):0.3f} MB" + ) + + def get_current_process_gpu_memory(self): + # Get all processes running on the GPU + processes = pynvml.nvmlDeviceGetComputeRunningProcesses(self.handle) + + # Iterate through the processes to find the current process + for process in processes: + if process.pid == self.current_pid: + memory_usage = process.usedGpuMemory + return memory_usage \ No newline at end of file diff --git a/src/tests/benchmark_env.py b/src/tests/benchmark_env.py new file mode 100644 index 0000000..513d253 --- /dev/null +++ b/src/tests/benchmark_env.py @@ -0,0 +1,135 @@ +from dataclasses import dataclass +from functools import partial +import jax +import jax.numpy as jnp +from luxai_s3.params import EnvParams +from luxai_s3.env import LuxAIS3Env +from luxai_s3.params import env_params_ranges +from luxai_s3.profiler import Profiler + +@dataclass +class Args: + num_envs: int = 64 + trials_per_benchmark: int = 5 + +if __name__ == "__main__": + import numpy as np + jax.config.update('jax_numpy_dtype_promotion', 'strict') + + np.random.seed(2) + + # the first 
env params is not batched and is used to initialize any static / unchaging values + # like map size, max units etc. + # note auto_reset=False for speed reasons. If True, the default jax code will attempt to reset each time and discard the reset if its not time to reset + # due to jax branching logic. It should be kept false and instead lax.scan followed by a reset after max episode steps should be used when possible since games + # can't end early. + env = LuxAIS3Env(auto_reset=False, fixed_env_params=EnvParams()) + num_envs = 100 + seed = 0 + rng_key = jax.random.key(seed) + reset_fn = jax.vmap(env.reset) + step_fn = jax.vmap(env.step) + + # sample random params initially + def sample_params(rng_key): + randomized_game_params = dict() + for k, v in env_params_ranges.items(): + rng_key, subkey = jax.random.split(rng_key) + if isinstance(v[0], int): + randomized_game_params[k] = jax.random.choice(subkey, jax.numpy.array(v, dtype=jnp.int16)) + else: + randomized_game_params[k] = jax.random.choice(subkey, jax.numpy.array(v, dtype=jnp.float32)) + params = EnvParams(**randomized_game_params) + return params + + rng_key, subkey = jax.random.split(rng_key) + env_params = jax.vmap(sample_params)(jax.random.split(subkey, num_envs)) + action_space = env.action_space() # note that this can generate sap actions beyond range atm + sample_action = jax.vmap(action_space.sample) + obs, state = reset_fn(jax.random.split(subkey, num_envs), env_params) + obs, state, reward, terminated_dict, truncated_dict, info = step_fn( + jax.random.split(subkey, num_envs), + state, + sample_action(jax.random.split(subkey, num_envs)), + env_params + ) + + max_episode_steps = (env.fixed_env_params.max_steps_in_match + 1) * env.fixed_env_params.match_count_per_episode + rng_key, subkey = jax.random.split(rng_key) + profiler = Profiler(output_format="stdout") + + + def benchmark_reset_for_loop_jax_step(rng_key): + rng_key, subkey = jax.random.split(rng_key) + states = [] + obs, state = reset_fn(jax.random.split(subkey, num_envs), env_params) + states.append(state) + for _ in range(max_episode_steps): + rng_key, subkey = jax.random.split(rng_key) + obs, state, reward, terminated_dict, truncated_dict, info = step_fn( + jax.random.split(subkey, num_envs), + state, + sample_action(jax.random.split(subkey, num_envs)), + env_params + ) + jax.block_until_ready(state) + states.append(state) + profiler.profile(partial(benchmark_reset_for_loop_jax_step, rng_key), "reset + for loop jax.step", total_steps=max_episode_steps, num_envs=num_envs, trials=5) + profiler.log_stats("reset + for loop jax.step") + + + def run_episode(rng_key, state, env_params): + def take_step(carry, _): + rng_key, state = carry + rng_key, subkey = jax.random.split(rng_key) + obs, state, reward, terminated_dict, truncated_dict, info = step_fn( + jax.random.split(subkey, num_envs), + state, + sample_action(jax.random.split(subkey, num_envs)), + env_params + ) + return (rng_key, state), (obs, state, reward, terminated_dict, truncated_dict, info) + _, (obs, state, reward, terminated_dict, truncated_dict, info) = jax.lax.scan(take_step, (rng_key, state), length=max_episode_steps, unroll=1) + return obs, state, reward, terminated_dict, truncated_dict, info + # compile the scan + print("Compiling run_episode") + run_episode = jax.jit(run_episode) + run_episode(subkey, state, env_params) + print("Compiling run_episode done") + + def benchmark_reset_jax_lax_scan_jax_step(rng_key): + rng_key, subkey = jax.random.split(rng_key) + obs, state = 
reset_fn(jax.random.split(subkey, num_envs), env_params) + rng_key, subkey = jax.random.split(rng_key) + # obs now has shape (max_episode_steps, num_envs, ...) + obs, state, reward, terminated_dict, truncated_dict, info = run_episode(subkey, state, env_params) + jax.block_until_ready(state) + profiler.profile(partial(benchmark_reset_jax_lax_scan_jax_step, rng_key), "reset + jax.lax.scan(jax.step)", total_steps=max_episode_steps, num_envs=num_envs, trials=5) + profiler.log_stats("reset + jax.lax.scan(jax.step)") + + def run_episode_and_reset(rng_key, env_params): + rng_key, subkey = jax.random.split(rng_key) + obs, state = reset_fn(jax.random.split(subkey, num_envs), env_params) + def take_step(carry, _): + rng_key, state = carry + rng_key, subkey = jax.random.split(rng_key) + obs, state, reward, terminated_dict, truncated_dict, info = step_fn( + jax.random.split(subkey, num_envs), + state, + sample_action(jax.random.split(subkey, num_envs)), + env_params + ) + return (rng_key, state), (obs, state, reward, terminated_dict, truncated_dict, info) + _, (obs, state, reward, terminated_dict, truncated_dict, info) = jax.lax.scan(take_step, (rng_key, state), length=max_episode_steps) + return obs, state, reward, terminated_dict, truncated_dict, info + # compile the scan + print("Compiling run_episode_and_reset") + run_episode_and_reset = jax.jit(run_episode_and_reset) + run_episode_and_reset(subkey, env_params) + print("Compiling run_episode_and_reset done") + def benchmark_jit_reset_lax_scan_jax_step(rng_key): + rng_key, subkey = jax.random.split(rng_key) + obs, state, reward, terminated_dict, truncated_dict, info = run_episode_and_reset(subkey, env_params) + jax.block_until_ready(state) + profiler.profile(partial(benchmark_jit_reset_lax_scan_jax_step, rng_key), "jit(reset + jax.lax.scan(jax.step))", total_steps=max_episode_steps, num_envs=num_envs, trials=5) + profiler.log_stats("jit(reset + jax.lax.scan(jax.step))") diff --git a/src/tests/test_gpu.py b/src/tests/test_gpu.py deleted file mode 100644 index d8845fd..0000000 --- a/src/tests/test_gpu.py +++ /dev/null @@ -1,65 +0,0 @@ -import time -import jax -import jax.numpy as jnp -import flax.serialization -from luxai_s3.params import EnvParams -from luxai_s3.env import LuxAIS3Env -from luxai_s3.params import env_params_ranges -from luxai_s3.state import gen_map -from luxai_s3.wrappers import LuxAIS3GymEnv, RecordEpisode - -if __name__ == "__main__": - import numpy as np - jax.config.update('jax_numpy_dtype_promotion', 'strict') - - np.random.seed(2) - - # the first env params is not batched and is used to initialize any static / unchaging values - # like map size, max units etc. - # note auto_reset=False for speed reasons. If True, the default jax code will attempt to reset each time and discard the reset if its not time to reset - # due to jax branching logic. It should be kept false and instead lax.scan followed by a reset after max episode steps should be used when possible since games - # can't end early. 
-    env = LuxAIS3Env(auto_reset=False, fixed_env_params=EnvParams())
-    num_envs = 100
-    seed = 0
-    rng_key = jax.random.key(seed)
-    reset_fn = jax.vmap(env.reset)
-    step_fn = jax.vmap(env.step)
-
-    # sample random params initially
-    def sample_params(rng_key):
-        randomized_game_params = dict()
-        for k, v in env_params_ranges.items():
-            rng_key, subkey = jax.random.split(rng_key)
-            if isinstance(v[0], int):
-                randomized_game_params[k] = jax.random.choice(subkey, jax.numpy.array(v, dtype=jnp.int16))
-            else:
-                randomized_game_params[k] = jax.random.choice(subkey, jax.numpy.array(v, dtype=jnp.float32))
-        params = EnvParams(**randomized_game_params)
-        return params
-
-    rng_key, subkey = jax.random.split(rng_key)
-    env_params = jax.vmap(sample_params)(jax.random.split(subkey, num_envs))
-    action_space = env.action_space() # note that this can generate sap actions beyond range atm
-    sample_action = jax.vmap(action_space.sample)
-    obs, state = reset_fn(jax.random.split(subkey, num_envs), env_params)
-    obs, state, reward, terminated_dict, truncated_dict, info = step_fn(
-        jax.random.split(subkey, num_envs),
-        state,
-        sample_action(jax.random.split(subkey, num_envs)),
-        env_params
-    )
-
-    max_episode_steps = env.fixed_env_params.max_steps_in_match * env.fixed_env_params.match_count_per_episode
-    print("Benchmarking reset + for loop over jax.step time")
-    stime = time.time()
-    obs, state = reset_fn(jax.random.split(subkey, num_envs), env_params)
-    for _ in range(max_episode_steps):
-        obs, state, reward, terminated_dict, truncated_dict, info = step_fn(
-            jax.random.split(subkey, num_envs),
-            state,
-            sample_action(jax.random.split(subkey, num_envs)),
-            env_params
-        )
-    etime = time.time()
-    print(f"FPS: {N * num_envs / (etime - stime):0.3f}. {N / (etime - stime):0.3f} parallel steps/s")

From e8a814ef3294494c3f160c6e6f2c5909f4550544 Mon Sep 17 00:00:00 2001
From: StoneT2000
Date: Sat, 26 Oct 2024 08:28:16 -0700
Subject: [PATCH 10/11] fix

---
 src/luxai_s3/profiler.py   |  8 ++++----
 src/tests/benchmark_env.py | 29 +++++++++++++++++------------
 2 files changed, 21 insertions(+), 16 deletions(-)

diff --git a/src/luxai_s3/profiler.py b/src/luxai_s3/profiler.py
index c42f3ca..c95c408 100644
--- a/src/luxai_s3/profiler.py
+++ b/src/luxai_s3/profiler.py
@@ -116,17 +116,17 @@ def log_stats(self, name: str):
             avg_stats[k].append(v)
         stats = {k: {"avg": np.mean(v), "std": np.std(v) if len(v) > 1 else None} for k, v in avg_stats.items()}
         self.log(
-            f"{name} ({len(stats)} trials)"
+            f"{name} ({len(self.stats[name])} trials)"
         )
         self.log(
-            f"{stats['fps']['avg']:0.3f} steps/s, {stats['psps']['avg']:0.3f} parallel steps/s, {stats['total_steps']['avg']} steps in {stats['dt']['avg']:0.3f}s"
+            f"AVG: {stats['fps']['avg']:0.3f} steps/s, {stats['psps']['avg']:0.3f} parallel steps/s, {stats['total_steps']['avg']} steps in {stats['dt']['avg']:0.3f}s"
         )
         if more_than_one_trial:
            self.log(
-                f"{stats['fps']['std']:0.3f} steps/s, {stats['psps']['std']:0.3f} parallel steps/s, {stats['total_steps']['std']} steps in {stats['dt']['std']:0.3f}s"
+                f"STD: {stats['fps']['std']:0.3f} steps/s, {stats['psps']['std']:0.3f} parallel steps/s, {stats['total_steps']['std']} steps in {stats['dt']['std']:0.3f}s"
            )
         self.log(
-            f"{' ' * 4}CPU mem: {avg_stats['cpu_mem_use'] / (1024**2):0.3f} MB, GPU mem: {avg_stats['gpu_mem_use'] / (1024**2):0.3f} MB"
+            f"{' ' * 4}CPU mem: {stats['cpu_mem_use']['avg'] / (1024**2):0.3f} MB, GPU mem: {stats['gpu_mem_use']['avg'] / (1024**2):0.3f} MB"
         )

     def get_current_process_gpu_memory(self):
diff --git a/src/tests/benchmark_env.py b/src/tests/benchmark_env.py
index 513d253..35e609c 100644
--- a/src/tests/benchmark_env.py
+++ b/src/tests/benchmark_env.py
@@ -1,7 +1,9 @@
 from dataclasses import dataclass
 from functools import partial
+from typing import Annotated
 import jax
 import jax.numpy as jnp
+import tyro
 from luxai_s3.params import EnvParams
 from luxai_s3.env import LuxAIS3Env
 from luxai_s3.params import env_params_ranges
@@ -9,14 +11,17 @@

 @dataclass
 class Args:
-    num_envs: int = 64
-    trials_per_benchmark: int = 5
+    num_envs: Annotated[int, tyro.conf.arg(aliases=["-n"])] = 64
+    trials_per_benchmark: Annotated[int, tyro.conf.arg(aliases=["-t"])] = 5
+    verbose: Annotated[int, tyro.conf.arg(aliases=["-v"])] = 0
+    seed: int = 0

 if __name__ == "__main__":
     import numpy as np
     jax.config.update('jax_numpy_dtype_promotion', 'strict')
+    args = tyro.cli(Args)

-    np.random.seed(2)
+    np.random.seed(args.seed)

     # the first env params is not batched and is used to initialize any static / unchaging values
     # like map size, max units etc.
@@ -24,8 +29,8 @@ class Args:
     # due to jax branching logic. It should be kept false and instead lax.scan followed by a reset after max episode steps should be used when possible since games
     # can't end early.
     env = LuxAIS3Env(auto_reset=False, fixed_env_params=EnvParams())
-    num_envs = 100
-    seed = 0
+    num_envs = args.num_envs
+    seed = args.seed
     rng_key = jax.random.key(seed)
     reset_fn = jax.vmap(env.reset)
     step_fn = jax.vmap(env.step)
@@ -74,7 +79,7 @@ def benchmark_reset_for_loop_jax_step(rng_key):
            )
            jax.block_until_ready(state)
            states.append(state)
-    profiler.profile(partial(benchmark_reset_for_loop_jax_step, rng_key), "reset + for loop jax.step", total_steps=max_episode_steps, num_envs=num_envs, trials=5)
+    profiler.profile(partial(benchmark_reset_for_loop_jax_step, rng_key), "reset + for loop jax.step", total_steps=max_episode_steps, num_envs=num_envs, trials=args.trials_per_benchmark)
     profiler.log_stats("reset + for loop jax.step")


@@ -92,10 +97,10 @@ def take_step(carry, _):
         _, (obs, state, reward, terminated_dict, truncated_dict, info) = jax.lax.scan(take_step, (rng_key, state), length=max_episode_steps, unroll=1)
         return obs, state, reward, terminated_dict, truncated_dict, info
     # compile the scan
-    print("Compiling run_episode")
+    if args.verbose: print("Compiling run_episode")
     run_episode = jax.jit(run_episode)
     run_episode(subkey, state, env_params)
-    print("Compiling run_episode done")
+    if args.verbose: print("Compiling run_episode done")

     def benchmark_reset_jax_lax_scan_jax_step(rng_key):
         rng_key, subkey = jax.random.split(rng_key)
@@ -104,7 +109,7 @@ def benchmark_reset_jax_lax_scan_jax_step(rng_key):
         # obs now has shape (max_episode_steps, num_envs, ...)
         obs, state, reward, terminated_dict, truncated_dict, info = run_episode(subkey, state, env_params)
         jax.block_until_ready(state)
-    profiler.profile(partial(benchmark_reset_jax_lax_scan_jax_step, rng_key), "reset + jax.lax.scan(jax.step)", total_steps=max_episode_steps, num_envs=num_envs, trials=5)
+    profiler.profile(partial(benchmark_reset_jax_lax_scan_jax_step, rng_key), "reset + jax.lax.scan(jax.step)", total_steps=max_episode_steps, num_envs=num_envs, trials=args.trials_per_benchmark)
     profiler.log_stats("reset + jax.lax.scan(jax.step)")

     def run_episode_and_reset(rng_key, env_params):
@@ -123,13 +128,13 @@ def take_step(carry, _):
         _, (obs, state, reward, terminated_dict, truncated_dict, info) = jax.lax.scan(take_step, (rng_key, state), length=max_episode_steps)
         return obs, state, reward, terminated_dict, truncated_dict, info
     # compile the scan
-    print("Compiling run_episode_and_reset")
+    if args.verbose: print("Compiling run_episode_and_reset")
     run_episode_and_reset = jax.jit(run_episode_and_reset)
     run_episode_and_reset(subkey, env_params)
-    print("Compiling run_episode_and_reset done")
+    if args.verbose: print("Compiling run_episode_and_reset done")
     def benchmark_jit_reset_lax_scan_jax_step(rng_key):
         rng_key, subkey = jax.random.split(rng_key)
         obs, state, reward, terminated_dict, truncated_dict, info = run_episode_and_reset(subkey, env_params)
         jax.block_until_ready(state)
-    profiler.profile(partial(benchmark_jit_reset_lax_scan_jax_step, rng_key), "jit(reset + jax.lax.scan(jax.step))", total_steps=max_episode_steps, num_envs=num_envs, trials=5)
+    profiler.profile(partial(benchmark_jit_reset_lax_scan_jax_step, rng_key), "jit(reset + jax.lax.scan(jax.step))", total_steps=max_episode_steps, num_envs=num_envs, trials=args.trials_per_benchmark)
     profiler.log_stats("jit(reset + jax.lax.scan(jax.step))")

From 4393e17ea81e69e92bcf506cddd2804530ccb94a Mon Sep 17 00:00:00 2001
From: StoneT2000
Date: Sat, 26 Oct 2024 08:34:46 -0700
Subject: [PATCH 11/11] update readme

---
 README.md | 13 +++++++++++++
 1 file changed, 13 insertions(+)

diff --git a/README.md b/README.md
index 4d1b12a..4133674 100644
--- a/README.md
+++ b/README.md
@@ -30,6 +30,19 @@ luxai-s3 path/to/bot/main.py path/to/bot/main.py --output replay.json

 Then upload the replay.json to the online visualizer here: https://s3vis.lux-ai.org/ (a link on the lux-ai.org website will be up soon)

+## GPU Acceleration
+
+Jax already provides decent CPU-based parallelization for batch-running the environment, but a GPU or TPU can increase environment throughput much further.
+
+To install jax with GPU/TPU support, follow the instructions [here](https://jax.readthedocs.io/en/latest/installation.html).
+
+To benchmark environment throughput on your hardware, run
+
+```
+pip install pynvml psutil
+python Lux-Design-S3/src/tests/benchmark_env.py -n 16384 -t 5 # 16384 envs, 5 trials each test
+```
+
 ### Starter Kits

 Each supported programming language/solution type has its own starter kit.
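
Before running the benchmark command above, it helps to confirm that jax is actually targeting the GPU/TPU rather than silently falling back to the CPU. The snippet below is an illustrative sanity check, separate from the patches themselves; it relies only on standard jax APIs (`jax.default_backend`, `jax.devices`, `jax.vmap`, `jax.random`).

```python
import jax

# Which backend jax selected at import time; should be "gpu" or "tpu", not "cpu".
print(jax.default_backend())
# The accelerator devices jax can see.
print(jax.devices())

# A small vmapped computation; if the backend above is "gpu"/"tpu", batched work
# like the vmapped env.reset/env.step in benchmark_env.py runs on the accelerator too.
keys = jax.random.split(jax.random.key(0), 16)
batch = jax.vmap(lambda k: jax.random.uniform(k, (256, 256)))(keys)
print(batch.shape)  # (16, 256, 256)
```

If `jax.default_backend()` reports `cpu`, the benchmark still runs, just at much lower throughput than a large `-n` setting would otherwise achieve.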