Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix batching to enable randomized game params + vmapping #22

Merged
merged 11 commits into from
Oct 26, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 13 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,19 @@ luxai-s3 path/to/bot/main.py path/to/bot/main.py --output replay.json

Then upload the replay.json to the online visualizer here: https://s3vis.lux-ai.org/ (a link on the lux-ai.org website will be up soon)

## GPU Acceleration

Jax will already provide some decent CPU based parallelization for batch running the environment. A GPU or TPU however can increase the environment throughput much more however.

To install jax with GPU/TPU support, you can follow the instructions [here](https://jax.readthedocs.io/en/latest/installation.html).

To benchmark your throughput speeds, you can run

```
pip install pynvml psutil
python Lux-Design-S3/src/tests/benchmark_env.py -n 16384 -t 5 # 16384 envs, 5 trials each test
```

### Starter Kits

Each supported programming language/solution type has its own starter kit.
Expand Down
16 changes: 11 additions & 5 deletions src/luxai_runner/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import json
import sys
from pathlib import Path
from typing import Dict, List
from typing import Annotated, Dict, List

import numpy as np
from luxai_runner.bot import Bot
Expand All @@ -15,23 +15,25 @@
from dataclasses import dataclass, field
from typing import Optional


@dataclass
class ReplayConfig:
save_format: str = "json"
"""Format of the replay file, can be either "html" or "json". HTML replays are easier to visualize, JSON replays are easier to analyze programmatically. Defaults to the extension of the path passed to --output, or "json" if there is no extension or it is invalid."""
compressed_obs: bool = True
"""Whether to save compressed observations or not. Compressed observations do not contain the full observation at each step. In particular, the map information is stored as the first observation, subsequent observations only store the changes that happened."""


@dataclass
class Args:
players: tyro.conf.Positional[List[str]]
"""Paths to player modules. If --tournament is passed as well, you can also pass a folder and we will look through all sub-folders for valid agents with main.py files (only works for python agents at the moment)."""
len: Optional[int] = 1000
"""Max episode length"""
output: Optional[str] = None
output: Annotated[Optional[str], tyro.conf.arg(aliases=["-o"])] = None
"""Where to output replays. Default is none and no replay is generated"""
replay: ReplayConfig = field(default_factory=lambda : ReplayConfig())
replay: ReplayConfig = field(default_factory=lambda: ReplayConfig())

verbose: int = 2
"""Verbose Level (0 = silent, 1 = (game-ending errors, debug logs from agents), 2 = warnings (non-game ending invalid actions), 3 = info (system info, unit collisions) )"""
seed: Optional[int] = None
Expand All @@ -47,6 +49,7 @@ class Args:
# skip_validate_action_space: bool = False
# """Set this for a small performance increase. Note that turning this on means the engine assumes your submitted actions are valid. If your actions are not well formatted there could be errors"""


def main():
args = tyro.cli(Args)

Expand All @@ -64,7 +67,9 @@ def main():
np.random.seed(args.seed)
cfg = EpisodeConfig(
players=args.players,
env_cls=lambda **kwargs: RecordEpisode(LuxAIS3GymEnv(numpy_output=True), save_on_close=False),
env_cls=lambda **kwargs: RecordEpisode(
LuxAIS3GymEnv(numpy_output=True), save_on_close=False
),
seed=args.seed,
env_cfg=dict(
# verbose=args.verbose,
Expand Down Expand Up @@ -121,5 +126,6 @@ def main():
print("Time Elapsed: ", etime - stime)
print("Rewards: ", results.rewards)


if __name__ == "__main__":
main()
24 changes: 16 additions & 8 deletions src/luxai_runner/episode.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,10 +31,12 @@ class EpisodeConfig:
save_replay_path: Optional[str] = None
replay_options: ReplayConfig = field(default_factory=ReplayConfig)


@dataclass
class EpisodeResults:
rewards: dict[str, float]


class Episode:
def __init__(self, cfg: EpisodeConfig) -> None:
self.cfg = cfg
Expand Down Expand Up @@ -135,11 +137,11 @@ async def run(self):
)

# if save_replay:
# replay = dict(observations=[], actions=[], dones=[], rewards=[])
# if self.cfg.replay_options.compressed_obs:
# replay["observations"].append(state_obs)
# else:
# replay["observations"].append(self.env.state.get_obs())
# replay = dict(observations=[], actions=[], dones=[], rewards=[])
# if self.cfg.replay_options.compressed_obs:
# replay["observations"].append(state_obs)
# else:
# replay["observations"].append(self.env.state.get_obs())

i = 0
while not game_done:
Expand All @@ -149,7 +151,10 @@ async def run(self):
action_coros = []
for player in players.values():
action = player.step(
obs=obs[player.agent], step=i, reward=rewards[player.agent], info=infos[player.agent]
obs=obs[player.agent],
step=i,
reward=rewards[player.agent],
info=infos[player.agent],
)
action_coros += [action]
agent_ids += [player.agent]
Expand All @@ -166,7 +171,9 @@ async def run(self):
else:
print(f"{agent_id} sent a invalid action {action}")
actions[agent_id] = None
new_state_obs, rewards, terminations, truncations, infos = self.env.step(actions)
new_state_obs, rewards, terminations, truncations, infos = self.env.step(
actions
)
i += 1
# TODO (stao): hard code to avoid using jax structs in the infos and sending those.
infos = dict(player_0=dict(), player_1=dict())
Expand Down Expand Up @@ -195,5 +202,6 @@ async def run(self):
await player.proc.cleanup()

return EpisodeResults(rewards=rewards)

def close(self):
pass
pass
6 changes: 5 additions & 1 deletion src/luxai_runner/process.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,11 @@ async def start(self):
base_file_path = os.path.basename(self.file_path)
if self.is_binary:
self._agent_process = await asyncio.create_subprocess_exec(
f"{cwd}\{base_file_path}" if sys.platform.startswith('win') else f"./{base_file_path}",
(
f"{cwd}\{base_file_path}"
if sys.platform.startswith("win")
else f"./{base_file_path}"
),
stdin=asyncio.subprocess.PIPE,
stdout=asyncio.subprocess.PIPE,
stderr=asyncio.subprocess.PIPE,
Expand Down
Loading