diff --git a/bnpm/plotting_helpers.py b/bnpm/plotting_helpers.py index 300c1ef..5bd6e4a 100644 --- a/bnpm/plotting_helpers.py +++ b/bnpm/plotting_helpers.py @@ -798,9 +798,10 @@ def __init__( def save( self, fig, + name_file: Union[str, List[str]]=None, path_save: str=None, dir_save: str=None, - name_file: Union[str, List[str]]=None, + overwrite: bool=None, ): """ Save the figures. @@ -808,6 +809,12 @@ def save( Args: fig (matplotlib.figure.Figure): Figure to save. + name_file (Union[str, List[str]): + Name of the file to save.\n + If None, then the title of the figure is used.\n + Path will be dir_save / name_file.\n + If a list of strings, then elements [:-1] will be subdirectories + and the last element will be the file name. path_save (str): Path to save the figure. Should not contain suffix. @@ -816,12 +823,9 @@ def save( dir_save (str): Directory to save the figure. If None, then the directory specified in the initialization is used. - name_file (Union[str, List[str]): - Name of the file to save.\n - If None, then the title of the figure is used.\n - Path will be dir_save / name_file.\n - If a list of strings, then elements [:-1] will be subdirectories - and the last element will be the file name. + overwrite (bool): + If True, then overwrite the file if it exists. If None, then the + value specified in the initialization is used. """ if not self.enabled: print('RH Warning: Figure_Saver is disabled. Not saving the figure.') if self.verbose > 1 else None @@ -858,10 +862,13 @@ def save( path_save = [path_save] if not isinstance(path_save, list) else path_save + ## Check overwrite + overwrite = self.overwrite if overwrite is None else overwrite + ## Save figure for path, form in zip(path_save, self.format_save): if Path(path).exists(): - if self.overwrite: + if overwrite: print(f'RH Warning: Overwriting file. File: {path} already exists.') if self.verbose > 0 else None else: print(f'RH Warning: Not saving anything. File exists and overwrite==False. {path} already exists.') if self.verbose > 0 else None @@ -904,11 +911,12 @@ def __call__( name_file: str=None, path_save: str=None, dir_save: str=None, + overwrite: bool=None, ): """ Calls save() method. """ - self.save(fig, path_save=path_save, name_file=name_file, dir_save=dir_save) + self.save(fig, path_save=path_save, name_file=name_file, dir_save=dir_save, overwrite=overwrite) def __repr__(self): return f"Figure_Saver(dir_save={self.dir_save}, format={self.format_save}, overwrite={self.overwrite}, kwargs_savefig={self.kwargs_savefig}, verbose={self.verbose})"