Skip to content

Commit

Permalink
(v3.5.1) - WandBLogger compability with previous WandB sessions (#428)
Browse files Browse the repository at this point in the history
* Updated tool version from 3.5.0 to 3.5.1

* WandbLogger: Added compatibility with wandb created sessions previously

* WandBLogger: Episode summary will not be dump with an empty episode

* Scripts (train_agent.py): Fixed default experiment name in order to be compatible with artifact save supported names

* Callbacks (LoggerEvalCallback): Fixed case when several evaluation episodes are set up.

* Documentation: WandBLogger section updated with this new feature.
  • Loading branch information
AlejandroCN7 authored Aug 23, 2024
1 parent 918df94 commit 7e6706c
Show file tree
Hide file tree
Showing 5 changed files with 43 additions and 29 deletions.
4 changes: 4 additions & 0 deletions docs/source/pages/wrappers.rst
Original file line number Diff line number Diff line change
Expand Up @@ -199,6 +199,10 @@ It is useful for real-time training process monitoring and is combinable with St
The initialization allows definition of the project, entity, run groups, tags, and whether code or outputs are saved as platform
artifacts, as well as dump frequency, excluded info keys, and excluded summary metric keys.

This wrapper can be used with a pre-existing WandB session, without the need to specify the entity or project
(which, if provided, will be ignored), such as when using sweeps. It still allows specifying other parameters during construction,
maintaining full functionality of the wrapper. If there is no pre-existing WandB session, the entity and project fields are required.

.. important:: A Weights and Biases account is required to use this wrapper, with an environment variable containing the API key for login.
For more information, visit `Weights and Biases <https://wandb.ai/site>`__.

Expand Down
2 changes: 1 addition & 1 deletion scripts/train/train_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,7 @@ def process_algorithm_parameters(alg_params: dict):
# ---------------------------------------------------------------------------- #
# Register run name #
# ---------------------------------------------------------------------------- #
experiment_date = datetime.today().strftime('%Y-%m-%d_%H:%M')
experiment_date = datetime.today().strftime('%Y-%m-%d_%H-%M')
experiment_name = conf['algorithm']['name'] + '-' + conf['environment'] + \
'-episodes-' + str(conf['episodes'])
if conf.get('seed'):
Expand Down
9 changes: 1 addition & 8 deletions sinergym/utils/callbacks.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,14 +175,7 @@ def _evaluate_policy(self) -> Dict[str, List[Any]]:

result = {key: [] for key in self.evaluation_columns}

for i in range(self.n_eval_episodes):
# If is not the first episode, save last episode metrics
if i > 0:
summary = self.eval_env.get_wrapper_attr("get_episode_summary")(
self.eval_env.get_wrapper_attr("episode"))
# Append values to result dictionary
for key in result.keys():
result[key].append(summary[key])
for _ in range(self.n_eval_episodes):

obs, _ = self.eval_env.reset()
state = None
Expand Down
55 changes: 36 additions & 19 deletions sinergym/utils/wrappers.py
Original file line number Diff line number Diff line change
Expand Up @@ -1204,10 +1204,11 @@ class WandBLogger(gym.Wrapper):

def __init__(self,
env: Env,
entity: str,
project_name: str,
entity: Optional[str] = None,
project_name: Optional[str] = None,
run_name: Optional[str] = None,
group: Optional[str] = None,
job_type: Optional[str] = None,
tags: Optional[List[str]] = None,
save_code: bool = False,
dump_frequency: int = 1000,
Expand All @@ -1228,9 +1229,11 @@ def __init__(self,
Args:
env (Env): Original Sinergym environment.
entity (str): The entity to which the project belongs.
project_name (str): The project name.
entity (Optional[str]): The entity to which the project belongs. If you are using a previous wandb run, it is not necessary to specify it. Defaults to None.
project_name (Optional[str]): The project name. If you are using a previous wandb run, it is not necessary to specify it. Defaults to None.
run_name (Optional[str]): The name of the run. Defaults to None (Sinergym env name + wandb unique identifier).
group (Optional[str]): The name of the group to which the run belongs. Defaults to None.
job_type (Optional[str]): The type of job. Defaults to None.
tags (Optional[List[str]]): List of tags for the run. Defaults to None.
save_code (bool): Whether to save the code in the run. Defaults to False.
dump_frequency (int): Frequency to dump log in platform. Defaults to 1000.
Expand All @@ -1257,13 +1260,26 @@ def __init__(self,
'name') + '_' + wandb.util.generate_id()

# Init WandB session
self.wandb_run = wandb.init(entity=entity,
project=project_name,
name=run_name,
group=group,
tags=tags,
save_code=save_code,
reinit=False)
# If there is no active run and entity and project has been specified,
# initialize a new one using the parameters
if wandb.run is None and (
entity is not None and project_name is not None):
self.wandb_run = wandb.init(entity=entity,
project=project_name,
name=run_name,
group=group,
job_type=job_type,
tags=tags,
save_code=save_code,
reinit=False)
# If there is an active run
elif wandb.run is not None:
# Use the active run
self.wandb_run = wandb.run
else:
self.logger.error(
'Error initializing WandB run, if project and entity are not specified, it should be a previous active wandb run, but it has not been found.')
raise RuntimeError

# Wandb finish with env.close flag
self.wandb_finish = True
Expand Down Expand Up @@ -1404,14 +1420,15 @@ def wandb_log(self) -> None:
def wandb_log_summary(self) -> None:
"""Log episode summary in WandB platform.
"""
# Get information from logger of LoggerWrapper
episode_summary = self.get_wrapper_attr(
'get_episode_summary')()
# Deleting excluded keys
episode_summary = {key: value for key, value in episode_summary.items(
) if key not in self.get_wrapper_attr('excluded_episode_summary_keys')}
# Log summary data in WandB
self._log_data({'episode_summaries': episode_summary})
if len(self.get_wrapper_attr('data_logger').rewards) > 0:
# Get information from logger of LoggerWrapper
episode_summary = self.get_wrapper_attr(
'get_episode_summary')()
# Deleting excluded keys
episode_summary = {key: value for key, value in episode_summary.items(
) if key not in self.get_wrapper_attr('excluded_episode_summary_keys')}
# Log summary data in WandB
self._log_data({'episode_summaries': episode_summary})

def save_artifact(self) -> None:
"""Save sinergym output as artifact in WandB platform.
Expand Down
2 changes: 1 addition & 1 deletion sinergym/version.txt
Original file line number Diff line number Diff line change
@@ -1 +1 @@
3.5.0
3.5.1

0 comments on commit 7e6706c

Please sign in to comment.