diff --git a/pokemonred_puffer/sweep.py b/pokemonred_puffer/sweep.py index 326bd13..e4820a9 100644 --- a/pokemonred_puffer/sweep.py +++ b/pokemonred_puffer/sweep.py @@ -179,6 +179,10 @@ def launch_agent( debug: bool = False, ): def _fn(): + import torch + + torch.compiler.reset() + agent_config: DictConfig = OmegaConf.load(os.environ["WANDB_SWEEP_PARAM_PATH"]).x.value agent_config = update_base_config(base_config, agent_config) train.train(config=agent_config, debug=debug, track=True)