diff --git a/ptan/ignite.py b/ptan/ignite.py index 0952b94..61d9b4f 100644 --- a/ptan/ignite.py +++ b/ptan/ignite.py @@ -3,15 +3,16 @@ import time from typing import Optional from ignite.engine import Engine, State -from ignite.engine import Events as EngineEvents +from ignite.engine import Events from ignite.handlers.timing import Timer -class EndOfEpisodeHandler: - class Events(enum.Enum): - EPISODE_COMPLETED = "episode_completed" - BOUND_REWARD_REACHED = "bound_reward_reached" +class EpisodeEvents(enum.Enum): + EPISODE_COMPLETED = "episode_completed" + BOUND_REWARD_REACHED = "bound_reward_reached" + +class EndOfEpisodeHandler: def __init__(self, exp_source: ptan.experience.ExperienceSource, alpha: float = 0.98, bound_avg_reward: Optional[float] = None): self._exp_source = exp_source @@ -19,10 +20,10 @@ def __init__(self, exp_source: ptan.experience.ExperienceSource, alpha: float = self._bound_avg_reward = bound_avg_reward def attach(self, engine: Engine): - engine.add_event_handler(EngineEvents.ITERATION_COMPLETED, self) - engine.register_events(*self.Events) - State.event_to_attr[self.Events.EPISODE_COMPLETED] = "episode" - State.event_to_attr[self.Events.BOUND_REWARD_REACHED] = "episode" + engine.add_event_handler(Events.ITERATION_COMPLETED, self) + engine.register_events(*EpisodeEvents) + State.event_to_attr[EpisodeEvents.EPISODE_COMPLETED] = "episode" + State.event_to_attr[EpisodeEvents.BOUND_REWARD_REACHED] = "episode" def __call__(self, engine: Engine): for reward, steps in self._exp_source.pop_rewards_steps(): @@ -32,9 +33,9 @@ def __call__(self, engine: Engine): engine.state.metrics['reward'] = reward engine.state.metrics['steps'] = steps self._update_smoothed_metrics(engine, reward, steps) - engine.fire_event(self.Events.EPISODE_COMPLETED) + engine.fire_event(EpisodeEvents.EPISODE_COMPLETED) if self._bound_avg_reward is not None and engine.state.metrics['avg_reward'] >= self._bound_avg_reward: - engine.fire_event(self.Events.BOUND_REWARD_REACHED) + engine.fire_event(EpisodeEvents.BOUND_REWARD_REACHED) def _update_smoothed_metrics(self, engine: Engine, reward: float, steps: int): for attr_name, val in zip(('avg_reward', 'avg_steps'), (reward, steps)): @@ -46,43 +47,58 @@ def _update_smoothed_metrics(self, engine: Engine, reward: float, steps: int): class EpisodeFPSHandler: - def __init__(self, fps_mul: float = 1.0): + FPS_METRIC = 'fps' + AVG_FPS_METRIC = 'avg_fps' + TIME_PASSED_METRIC = 'time_passed' + + def __init__(self, fps_mul: float = 1.0, fps_smooth_alpha: float = 0.98): self._timer = Timer(average=True) self._fps_mul = fps_mul self._started_ts = time.time() + self._fps_smooth_alpha = fps_smooth_alpha def attach(self, engine: Engine): - self._timer.attach(engine, step=EngineEvents.ITERATION_COMPLETED) - engine.add_event_handler(EndOfEpisodeHandler.Events.EPISODE_COMPLETED, self) + self._timer.attach(engine, step=Events.ITERATION_COMPLETED) + engine.add_event_handler(EpisodeEvents.EPISODE_COMPLETED, self) def __call__(self, engine: Engine): t_val = self._timer.value() if engine.state.iteration > 1: - engine.state.metrics['fps'] = self._fps_mul / t_val - engine.state.metrics['time_passed'] = time.time() - self._started_ts + fps = self._fps_mul / t_val + avg_fps = engine.state.metrics.get(self.AVG_FPS_METRIC) + if avg_fps is None: + avg_fps = fps + else: + avg_fps *= self._fps_smooth_alpha + avg_fps += (1-self._fps_smooth_alpha) * fps + engine.state.metrics[self.AVG_FPS_METRIC] = avg_fps + engine.state.metrics[self.FPS_METRIC] = fps + engine.state.metrics[self.TIME_PASSED_METRIC] = time.time() - self._started_ts self._timer.reset() +class PeriodEvents(enum.Enum): + ITERS_10_COMPLETED = "iterations_10_completed" + ITERS_100_COMPLETED = "iterations_100_completed" + ITERS_1000_COMPLETED = "iterations_1000_completed" + + class PeriodicEvents: """ The same as CustomPeriodicEvent from ignite.contrib, but use true amount of iterations, which is good for TensorBoard """ - class Events(enum.Enum): - ITERATIONS_10_COMPLETED = "iterations_10_completed" - ITERATIONS_100_COMPLETED = "iterations_100_completed" - ITERATIONS_1000_COMPLETED = "iterations_1000_completed" INTERVAL_TO_EVENT = { - 10: Events.ITERATIONS_10_COMPLETED, - 100: Events.ITERATIONS_100_COMPLETED, - 1000: Events.ITERATIONS_1000_COMPLETED, + 10: PeriodEvents.ITERS_10_COMPLETED, + 100: PeriodEvents.ITERS_100_COMPLETED, + 1000: PeriodEvents.ITERS_1000_COMPLETED, } def attach(self, engine: Engine): - engine.add_event_handler(EngineEvents.ITERATION_COMPLETED, self) - engine.register_events(*self.Events) - for e in self.Events: + engine.add_event_handler(Events.ITERATION_COMPLETED, self) + engine.register_events(*PeriodEvents) + for e in PeriodEvents: State.event_to_attr[e] = "iteration" def __call__(self, engine: Engine):