Skip to content

Commit

Permalink
Added dump_frequency in script train_agent.py
Browse files Browse the repository at this point in the history
  • Loading branch information
AlejandroCN7 committed Sep 14, 2023
1 parent 9a1968b commit 8bcfbe1
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 6 deletions.
13 changes: 8 additions & 5 deletions scripts/train_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,8 @@
# env for evaluation if is enabled
eval_env = None
if conf.get('evaluation'):
eval_name = conf['evaluation'].get('name', env.name + '-EVAL')
eval_name = conf['evaluation'].get(
'name', env.name + '-EVAL')
env_params.update({'env_name': eval_name})
eval_env = gym.make(
conf['environment'],
Expand Down Expand Up @@ -224,7 +225,8 @@
# ---------------------------------------------------------------------------- #
# Calculating total training timesteps based on number of episodes #
# ---------------------------------------------------------------------------- #
timesteps = conf['episodes'] * (env.timestep_per_episode - 1)
timesteps = conf['episodes'] * \
(env.timestep_per_episode - 1)

# ---------------------------------------------------------------------------- #
# CALLBACKS #
Expand All @@ -239,8 +241,8 @@
'/best_model/',
log_path=eval_env.experiment_path +
'/best_model/',
eval_freq=(eval_env.timestep_per_episode - 1) *
conf['evaluation']['eval_freq'] - 1,
eval_freq=(eval_env.timestep_per_episode) *
conf['evaluation']['eval_freq'],
deterministic=True,
render=False,
n_eval_episodes=conf['evaluation']['eval_length'])
Expand All @@ -258,7 +260,8 @@
WandBOutputFormat()])
model.set_logger(logger)
# Append callback
log_callback = LoggerCallback()
dump_frequency = conf['wandb'].get('dump_frequency', 100)
log_callback = LoggerCallback(dump_frequency=dump_frequency)
callbacks.append(log_callback)

callback = CallbackList(callbacks)
Expand Down
3 changes: 2 additions & 1 deletion scripts/train_agent_example.json
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@
"entity": "alex_ugr"
},
"artifact_name": "experiment1",
"artifact_type": "training"
"artifact_type": "training",
"dump_frequency": 500
}
}

0 comments on commit 8bcfbe1

Please sign in to comment.