-
Notifications
You must be signed in to change notification settings - Fork 7
/
callbacks.py
42 lines (31 loc) · 1.3 KB
/
callbacks.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
from stable_baselines3.common.callbacks import BaseCallback
from tqdm.auto import tqdm
class ProgressBarCallback(BaseCallback):
    """
    Callback that mirrors the model's timestep counter onto a tqdm progress bar.

    :param pbar: (tqdm.pbar) Progress bar object
    """

    def __init__(self, pbar):
        super().__init__()
        self._pbar = pbar

    def _on_step(self):
        """
        Sync the progress bar position with ``self.num_timesteps``.

        :return: (bool) Always True. Stable-Baselines3 aborts training when
            ``_on_step`` returns a falsy value, and that includes ``None``.
        """
        # Set the absolute position instead of incrementing: more than one
        # timestep may elapse between calls (e.g. with several envs).
        self._pbar.n = self.num_timesteps
        # update(0) asks tqdm to redraw without advancing the counter again.
        self._pbar.update(0)
        return True
# Context manager around ProgressBarCallback: using it in a 'with' block ensures the progress bar is created on entry and closed on exit.
class ProgressBarManager:
    """
    Context manager owning the tqdm progress bar for one training run.

    On entry it creates the bar and returns a ProgressBarCallback bound to it.
    On exit it closes the bar, even if the body of the ``with`` block raised.

    :param total_timesteps: (int) Total number of timesteps the bar tracks.
    """

    def __init__(self, total_timesteps):
        self.pbar = None  # created lazily in __enter__
        self.total_timesteps = total_timesteps

    def __enter__(self):
        """Create the progress bar and return a callback that updates it."""
        self.pbar = tqdm(total=self.total_timesteps, desc="Steps", leave=False)
        return ProgressBarCallback(self.pbar)

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Close the progress bar; any exception from the body propagates."""
        try:
            # Only show the bar as finished on a clean exit; if training
            # raised, leave it at the step where it stopped.
            if exc_type is None:
                self.pbar.n = self.total_timesteps
                self.pbar.update(0)
        finally:
            self.pbar.close()
        # Implicitly returning None (falsy) does not suppress exceptions.
## Callback usage
# model = TD3("MlpPolicy", "Pendulum-v1", verbose=0)
# with ProgressBarManager(2000) as callback:  # this guarantees that the tqdm progress bar closes correctly
#     model.learn(2000, callback=callback)