diff --git a/docs/source/pages/wrappers.rst b/docs/source/pages/wrappers.rst index 634d4fb73..d54728ec9 100644 --- a/docs/source/pages/wrappers.rst +++ b/docs/source/pages/wrappers.rst @@ -199,6 +199,10 @@ It is useful for real-time training process monitoring and is combinable with St The initialization allows definition of the project, entity, run groups, tags, and whether code or outputs are saved as platform artifacts, as well as dump frequency, excluded info keys, and excluded summary metric keys. +This wrapper can also be used with a pre-existing WandB session (for example, when using sweeps). In that case, the entity and project +do not need to be specified and, if provided, are ignored. Other parameters can still be specified during construction, +so the wrapper keeps its full functionality. If there is no pre-existing WandB session, the entity and project fields are required. + .. important:: A Weights and Biases account is required to use this wrapper, with an environment variable containing the API key for login. For more information, visit `Weights and Biases `__. 
diff --git a/scripts/train/train_agent.py b/scripts/train/train_agent.py index 6c216dc68..742cccc93 100644 --- a/scripts/train/train_agent.py +++ b/scripts/train/train_agent.py @@ -110,7 +110,7 @@ def process_algorithm_parameters(alg_params: dict): # ---------------------------------------------------------------------------- # # Register run name # # ---------------------------------------------------------------------------- # - experiment_date = datetime.today().strftime('%Y-%m-%d_%H:%M') + experiment_date = datetime.today().strftime('%Y-%m-%d_%H-%M') experiment_name = conf['algorithm']['name'] + '-' + conf['environment'] + \ '-episodes-' + str(conf['episodes']) if conf.get('seed'): diff --git a/sinergym/utils/callbacks.py b/sinergym/utils/callbacks.py index 4649cffbe..947a06275 100644 --- a/sinergym/utils/callbacks.py +++ b/sinergym/utils/callbacks.py @@ -175,14 +175,7 @@ def _evaluate_policy(self) -> Dict[str, List[Any]]: result = {key: [] for key in self.evaluation_columns} - for i in range(self.n_eval_episodes): - # If is not the first episode, save last episode metrics - if i > 0: - summary = self.eval_env.get_wrapper_attr("get_episode_summary")( - self.eval_env.get_wrapper_attr("episode")) - # Append values to result dictionary - for key in result.keys(): - result[key].append(summary[key]) + for _ in range(self.n_eval_episodes): obs, _ = self.eval_env.reset() state = None diff --git a/sinergym/utils/wrappers.py b/sinergym/utils/wrappers.py index 3f0583255..64c0bbd67 100644 --- a/sinergym/utils/wrappers.py +++ b/sinergym/utils/wrappers.py @@ -1204,10 +1204,11 @@ class WandBLogger(gym.Wrapper): def __init__(self, env: Env, - entity: str, - project_name: str, + entity: Optional[str] = None, + project_name: Optional[str] = None, run_name: Optional[str] = None, group: Optional[str] = None, + job_type: Optional[str] = None, tags: Optional[List[str]] = None, save_code: bool = False, dump_frequency: int = 1000, @@ -1228,9 +1229,11 @@ def __init__(self, Args: env 
(Env): Original Sinergym environment. - entity (str): The entity to which the project belongs. - project_name (str): The project name. + entity (Optional[str]): The entity to which the project belongs. If you are using a previous wandb run, it is not necessary to specify it. Defaults to None. + project_name (Optional[str]): The project name. If you are using a previous wandb run, it is not necessary to specify it. Defaults to None. run_name (Optional[str]): The name of the run. Defaults to None (Sinergym env name + wandb unique identifier). + group (Optional[str]): The name of the group to which the run belongs. Defaults to None. + job_type (Optional[str]): The type of job. Defaults to None. tags (Optional[List[str]]): List of tags for the run. Defaults to None. save_code (bool): Whether to save the code in the run. Defaults to False. dump_frequency (int): Frequency to dump log in platform. Defaults to 1000. @@ -1257,13 +1260,26 @@ def __init__(self, 'name') + '_' + wandb.util.generate_id() # Init WandB session - self.wandb_run = wandb.init(entity=entity, - project=project_name, - name=run_name, - group=group, - tags=tags, - save_code=save_code, - reinit=False) + # If there is no active run and entity and project has been specified, + # initialize a new one using the parameters + if wandb.run is None and ( + entity is not None and project_name is not None): + self.wandb_run = wandb.init(entity=entity, + project=project_name, + name=run_name, + group=group, + job_type=job_type, + tags=tags, + save_code=save_code, + reinit=False) + # If there is an active run + elif wandb.run is not None: + # Use the active run + self.wandb_run = wandb.run + else: + self.logger.error( + 'Error initializing WandB run, if project and entity are not specified, it should be a previous active wandb run, but it has not been found.') + raise RuntimeError # Wandb finish with env.close flag self.wandb_finish = True @@ -1404,14 +1420,15 @@ def wandb_log(self) -> None: def 
wandb_log_summary(self) -> None: """Log episode summary in WandB platform. """ - # Get information from logger of LoggerWrapper - episode_summary = self.get_wrapper_attr( - 'get_episode_summary')() - # Deleting excluded keys - episode_summary = {key: value for key, value in episode_summary.items( - ) if key not in self.get_wrapper_attr('excluded_episode_summary_keys')} - # Log summary data in WandB - self._log_data({'episode_summaries': episode_summary}) + if len(self.get_wrapper_attr('data_logger').rewards) > 0: + # Get information from logger of LoggerWrapper + episode_summary = self.get_wrapper_attr( + 'get_episode_summary')() + # Deleting excluded keys + episode_summary = {key: value for key, value in episode_summary.items( + ) if key not in self.get_wrapper_attr('excluded_episode_summary_keys')} + # Log summary data in WandB + self._log_data({'episode_summaries': episode_summary}) def save_artifact(self) -> None: """Save sinergym output as artifact in WandB platform. diff --git a/sinergym/version.txt b/sinergym/version.txt index e5b820341..3c8ff8c36 100644 --- a/sinergym/version.txt +++ b/sinergym/version.txt @@ -1 +1 @@ -3.5.0 \ No newline at end of file +3.5.1 \ No newline at end of file