diff --git a/scripts/train_agent.py b/scripts/train_agent.py index 2400e50ab3..c4661a03b7 100644 --- a/scripts/train_agent.py +++ b/scripts/train_agent.py @@ -106,7 +106,8 @@ # env for evaluation if is enabled eval_env = None if conf.get('evaluation'): - eval_name = conf['evaluation'].get('name', env.name + '-EVAL') + eval_name = conf['evaluation'].get( + 'name', env.name + '-EVAL') env_params.update({'env_name': eval_name}) eval_env = gym.make( conf['environment'], @@ -224,7 +225,8 @@ # ---------------------------------------------------------------------------- # # Calculating total training timesteps based on number of episodes # # ---------------------------------------------------------------------------- # - timesteps = conf['episodes'] * (env.timestep_per_episode - 1) + timesteps = conf['episodes'] * \ + (env.timestep_per_episode - 1) # ---------------------------------------------------------------------------- # # CALLBACKS # @@ -239,8 +241,8 @@ '/best_model/', log_path=eval_env.experiment_path + '/best_model/', - eval_freq=(eval_env.timestep_per_episode - 1) * - conf['evaluation']['eval_freq'] - 1, + eval_freq=(eval_env.timestep_per_episode) * + conf['evaluation']['eval_freq'], deterministic=True, render=False, n_eval_episodes=conf['evaluation']['eval_length']) @@ -258,7 +260,8 @@ WandBOutputFormat()]) model.set_logger(logger) # Append callback - log_callback = LoggerCallback() + dump_frequency = conf['wandb'].get('dump_frequency', 100) + log_callback = LoggerCallback(dump_frequency=dump_frequency) callbacks.append(log_callback) callback = CallbackList(callbacks) diff --git a/scripts/train_agent_example.json b/scripts/train_agent_example.json index 29feac2ef9..8bbf858ca9 100644 --- a/scripts/train_agent_example.json +++ b/scripts/train_agent_example.json @@ -43,6 +43,7 @@ "entity": "alex_ugr" }, "artifact_name": "experiment1", - "artifact_type": "training" + "artifact_type": "training", + "dump_frequency": 500 } } \ No newline at end of file