diff --git a/config.yaml b/config.yaml index b268dd7..d8353c2 100644 --- a/config.yaml +++ b/config.yaml @@ -1,5 +1,6 @@ defaults: - override hydra/launcher: joblib +algorithm: kaqn env_id: CartPole-v1 batch_size: 256 n_episodes: 500 diff --git a/experiment.sh b/experiment.sh index 270d4ba..71b35e2 100644 --- a/experiment.sh +++ b/experiment.sh @@ -1,5 +1,5 @@ #!/bin/sh echo "Starting experiments with MLP..." -python kaqn.py --multirun seed="range(32)" method=MLP width=32 +python run_experiment.py --multirun seed="range(32)" method=MLP width=32 echo "Starting experiments with KAN..." -python kaqn.py --multirun seed="range(32)" \ No newline at end of file +python run_experiment.py --multirun seed="range(32)" \ No newline at end of file diff --git a/kaqn.py b/kaqn.py index 56f64e9..bce5d60 100644 --- a/kaqn.py +++ b/kaqn.py @@ -111,7 +111,6 @@ def set_all_seeds(seed): torch.use_deterministic_algorithms(True) -@hydra.main(config_path=".", config_name="config", version_base=None) def main(config: DictConfig): set_all_seeds(config.seed) env = gym.make(config.env_id) @@ -222,6 +221,3 @@ def main(config: DictConfig): if episode % config.target_update_freq == 0: target_network.load_state_dict(q_network.state_dict()) - -if __name__ == "__main__": - main() diff --git a/run_experiment.py b/run_experiment.py new file mode 100644 index 0000000..6a8b7cd --- /dev/null +++ b/run_experiment.py @@ -0,0 +1,32 @@ +""" +Main driver to dispatch experiments for different algorithms. + +Algorithm name should be specified in a config file as, e.g. +``` +algorithm: kaqn +``` +and supported in the `imports` dispatch dict below. + +Then, this file can be invoked simply as +``` +python run_experiment.py +``` +""" + +import hydra +from omegaconf import DictConfig + +# Dispatch algorithm name to a subprocess call. We probably don't want to make +# this arbitrary (security reasons) +imports = { + "kaqn": exec('from kaqn import main as algo'), +} + + +@hydra.main(config_path=".", config_name="config", version_base=None) +def main(config: DictConfig): + imports[config.algorithm] + algo(config) + +if __name__ == "__main__": + main() \ No newline at end of file