Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add main experiment driver and specify kaqn as modifiable algo #11

Draft
wants to merge 1 commit into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions config.yaml
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
defaults:
- override hydra/launcher: joblib
algorithm: kaqn
env_id: CartPole-v1
batch_size: 256
n_episodes: 500
Expand Down
4 changes: 2 additions & 2 deletions experiment.sh
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
#!/bin/sh
echo "Starting experiments with MLP..."
python kaqn.py --multirun seed="range(32)" method=MLP width=32
python run_experiment.py --multirun seed="range(32)" method=MLP width=32
echo "Starting experiments with KAN..."
python kaqn.py --multirun seed="range(32)"
python run_experiment.py --multirun seed="range(32)"
4 changes: 0 additions & 4 deletions kaqn.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,6 @@ def set_all_seeds(seed):
torch.use_deterministic_algorithms(True)


@hydra.main(config_path=".", config_name="config", version_base=None)
def main(config: DictConfig):
set_all_seeds(config.seed)
env = gym.make(config.env_id)
Expand Down Expand Up @@ -222,6 +221,3 @@ def main(config: DictConfig):
if episode % config.target_update_freq == 0:
target_network.load_state_dict(q_network.state_dict())


if __name__ == "__main__":
main()
32 changes: 32 additions & 0 deletions run_experiment.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
"""
Main driver to dispatch experiments for different algorithms.

Algorithm name should be specified in a config file as, e.g.
```
algorithm: kaqn
```
and supported in the `imports` dispatch dict below.

Then, this file can be invoked simply as
```
python run_experiment.py
```
"""

import hydra
from omegaconf import DictConfig

# Dispatch algorithm name to a subprocess call. We probably don't want to make
# this arbitrary (security reasons)
imports = {
"kaqn": exec('from kaqn import main as algo'),
}


@hydra.main(config_path=".", config_name="config", version_base=None)
def main(config: DictConfig):
imports[config.algorithm]
algo(config)

if __name__ == "__main__":
main()