Skip to content

Commit

Permalink
Refactor events, smooth fps
Browse files Browse the repository at this point in the history
  • Loading branch information
Shmuma committed Jul 31, 2019
1 parent 195ba98 commit dbca226
Showing 1 changed file with 42 additions and 26 deletions.
68 changes: 42 additions & 26 deletions ptan/ignite.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,26 +3,27 @@
import time
from typing import Optional
from ignite.engine import Engine, State
from ignite.engine import Events as EngineEvents
from ignite.engine import Events
from ignite.handlers.timing import Timer


class EndOfEpisodeHandler:
class Events(enum.Enum):
EPISODE_COMPLETED = "episode_completed"
BOUND_REWARD_REACHED = "bound_reward_reached"
class EpisodeEvents(enum.Enum):
EPISODE_COMPLETED = "episode_completed"
BOUND_REWARD_REACHED = "bound_reward_reached"


class EndOfEpisodeHandler:
def __init__(self, exp_source: ptan.experience.ExperienceSource, alpha: float = 0.98,
bound_avg_reward: Optional[float] = None):
self._exp_source = exp_source
self._alpha = alpha
self._bound_avg_reward = bound_avg_reward

def attach(self, engine: Engine):
engine.add_event_handler(EngineEvents.ITERATION_COMPLETED, self)
engine.register_events(*self.Events)
State.event_to_attr[self.Events.EPISODE_COMPLETED] = "episode"
State.event_to_attr[self.Events.BOUND_REWARD_REACHED] = "episode"
engine.add_event_handler(Events.ITERATION_COMPLETED, self)
engine.register_events(*EpisodeEvents)
State.event_to_attr[EpisodeEvents.EPISODE_COMPLETED] = "episode"
State.event_to_attr[EpisodeEvents.BOUND_REWARD_REACHED] = "episode"

def __call__(self, engine: Engine):
for reward, steps in self._exp_source.pop_rewards_steps():
Expand All @@ -32,9 +33,9 @@ def __call__(self, engine: Engine):
engine.state.metrics['reward'] = reward
engine.state.metrics['steps'] = steps
self._update_smoothed_metrics(engine, reward, steps)
engine.fire_event(self.Events.EPISODE_COMPLETED)
engine.fire_event(EpisodeEvents.EPISODE_COMPLETED)
if self._bound_avg_reward is not None and engine.state.metrics['avg_reward'] >= self._bound_avg_reward:
engine.fire_event(self.Events.BOUND_REWARD_REACHED)
engine.fire_event(EpisodeEvents.BOUND_REWARD_REACHED)

def _update_smoothed_metrics(self, engine: Engine, reward: float, steps: int):
for attr_name, val in zip(('avg_reward', 'avg_steps'), (reward, steps)):
Expand All @@ -46,43 +47,58 @@ def _update_smoothed_metrics(self, engine: Engine, reward: float, steps: int):


class EpisodeFPSHandler:
def __init__(self, fps_mul: float = 1.0):
FPS_METRIC = 'fps'
AVG_FPS_METRIC = 'avg_fps'
TIME_PASSED_METRIC = 'time_passed'

def __init__(self, fps_mul: float = 1.0, fps_smooth_alpha: float = 0.98):
self._timer = Timer(average=True)
self._fps_mul = fps_mul
self._started_ts = time.time()
self._fps_smooth_alpha = fps_smooth_alpha

def attach(self, engine: Engine):
self._timer.attach(engine, step=EngineEvents.ITERATION_COMPLETED)
engine.add_event_handler(EndOfEpisodeHandler.Events.EPISODE_COMPLETED, self)
self._timer.attach(engine, step=Events.ITERATION_COMPLETED)
engine.add_event_handler(EpisodeEvents.EPISODE_COMPLETED, self)

def __call__(self, engine: Engine):
t_val = self._timer.value()
if engine.state.iteration > 1:
engine.state.metrics['fps'] = self._fps_mul / t_val
engine.state.metrics['time_passed'] = time.time() - self._started_ts
fps = self._fps_mul / t_val
avg_fps = engine.state.metrics.get(self.AVG_FPS_METRIC)
if avg_fps is None:
avg_fps = fps
else:
avg_fps *= self._fps_smooth_alpha
avg_fps += (1-self._fps_smooth_alpha) * fps
engine.state.metrics[self.AVG_FPS_METRIC] = avg_fps
engine.state.metrics[self.FPS_METRIC] = fps
engine.state.metrics[self.TIME_PASSED_METRIC] = time.time() - self._started_ts
self._timer.reset()


class PeriodEvents(enum.Enum):
ITERS_10_COMPLETED = "iterations_10_completed"
ITERS_100_COMPLETED = "iterations_100_completed"
ITERS_1000_COMPLETED = "iterations_1000_completed"


class PeriodicEvents:
"""
The same as CustomPeriodicEvent from ignite.contrib, but use true amount of iterations,
which is good for TensorBoard
"""
class Events(enum.Enum):
ITERATIONS_10_COMPLETED = "iterations_10_completed"
ITERATIONS_100_COMPLETED = "iterations_100_completed"
ITERATIONS_1000_COMPLETED = "iterations_1000_completed"

INTERVAL_TO_EVENT = {
10: Events.ITERATIONS_10_COMPLETED,
100: Events.ITERATIONS_100_COMPLETED,
1000: Events.ITERATIONS_1000_COMPLETED,
10: PeriodEvents.ITERS_10_COMPLETED,
100: PeriodEvents.ITERS_100_COMPLETED,
1000: PeriodEvents.ITERS_1000_COMPLETED,
}

def attach(self, engine: Engine):
engine.add_event_handler(EngineEvents.ITERATION_COMPLETED, self)
engine.register_events(*self.Events)
for e in self.Events:
engine.add_event_handler(Events.ITERATION_COMPLETED, self)
engine.register_events(*PeriodEvents)
for e in PeriodEvents:
State.event_to_attr[e] = "iteration"

def __call__(self, engine: Engine):
Expand Down

0 comments on commit dbca226

Please sign in to comment.