diff --git a/examples/async_pcb_insert_drq/async_drq_randomized.py b/examples/async_pcb_insert_drq/async_drq_randomized.py
index bdc5dac6..52bf7f83 100644
--- a/examples/async_pcb_insert_drq/async_drq_randomized.py
+++ b/examples/async_pcb_insert_drq/async_drq_randomized.py
@@ -5,9 +5,12 @@
 import jax
 import jax.numpy as jnp
 import numpy as np
+import pynput
+import os
 import tqdm
 from absl import app, flags
 from flax.training import checkpoints
+import threading
 
 import gymnasium as gym
 from gymnasium.wrappers.record_episode_statistics import RecordEpisodeStatistics
@@ -88,6 +91,50 @@ def print_green(x):
     return print("\033[92m {}\033[00m".format(x))
 
 
+PAUSE_EVENT_FLAG = threading.Event()
+PAUSE_EVENT_FLAG.clear()  # clear() to continue the actor/learner loop, set() to pause
+
+
+def pause_callback(key):
+    """Request a pause of the actor/learner loop when the Pause key is pressed.
+
+    Registered as the ``on_press`` callback of the keyboard listener below.
+    Pause is a rarely used key, chosen to avoid conflicts: this listener is
+    always on, even when the program is not in focus.
+    """
+    if key == pynput.keyboard.Key.pause:
+        print("Requested pause training")
+        # set the PAUSE FLAG to pause the actor/learner loop
+        PAUSE_EVENT_FLAG.set()
+
+
+def read_pause_response(prompt):
+    """Prompt on stdin and return the reply stripped and lower-cased.
+
+    Normalizing lets replies such as "C" or " c " still select an option.
+    """
+    return input(prompt).strip().lower()
+
+
+def save_if_supported(store, path):
+    """Call ``store.save(path)`` when ``store`` has a ``save`` method.
+
+    Saving is not yet supported for QueuedDataStore, so report that instead
+    of raising AttributeError in the middle of a shutdown.
+    """
+    save = getattr(store, "save", None)
+    if save is None:
+        print(f"{type(store).__name__} does not support saving, skipping")
+        return False
+    save(path)
+    return True
+
+
+listener = pynput.keyboard.Listener(
+    on_press=pause_callback
+)  # to enable keyboard based pause
+listener.start()
+
 ##############################################################################
 
 
@@ -130,6 +177,20 @@ def actor(agent: DrQAgent, data_store, env, sampling_rng):
                     print(reward)
                     print(f"{success_counter}/{episode + 1}")
 
+            # if pause event is requested, pause the actor
+            if PAUSE_EVENT_FLAG.is_set():
+                print("Actor eval loop interrupted")
+                response = read_pause_response(
+                    "Do you want to continue (c), or exit (e)? \n"
+                )
+                if response == "c":
+                    # update PAUSE FLAG to continue training
+                    PAUSE_EVENT_FLAG.clear()
+                    print("Continuing")
+                else:
+                    print("Stopping actor eval")
+                    break
+
-        print(f"success rate: {success_counter / FLAGS.eval_n_trajs}")
+        print(f"success rate: {success_counter / (episode + 1)}")
         print(f"average time: {np.mean(time_list)}")
         return  # after done eval, return and exit
@@ -209,6 +270,26 @@ def update_params(params):
             stats = {"timer": timer.get_average_times()}
             client.request("send-stats", stats)
 
+        if PAUSE_EVENT_FLAG.is_set():
+            print_green("Actor loop interrupted")
+            response = read_pause_response(
+                "Do you want to continue (c), save replay buffer and exit (s) or simply exit (e)? "
+            )
+            if response == "c":
+                print("Continuing")
+                PAUSE_EVENT_FLAG.clear()
+            else:
+                if response == "s":
+                    print("Saving replay buffer")
+                    save_if_supported(data_store, "replay_buffer_actor.npz")
+                else:
+                    print("Replay buffer not saved")
+                print("Stopping actor client")
+                client.stop()
+                break
+
+    print("Actor loop finished")
+
 
 ##############################################################################
 
@@ -305,6 +386,31 @@ def stats_callback(type: str, payload: dict) -> dict:
 
         update_steps += 1
 
+        if PAUSE_EVENT_FLAG.is_set():
+            print("Learner loop interrupted")
+            response = read_pause_response(
+                "Do you want to continue (c), save training state and exit (s) or simply exit (e)? \n"
+            )
+            if response == "c":
+                print("Continuing")
+                PAUSE_EVENT_FLAG.clear()
+            else:
+                if response == "s":
+                    print("Saving learner state")
+                    checkpoints.save_checkpoint(
+                        FLAGS.checkpoint_path, agent.state, step=update_steps, keep=100
+                    )
+                    save_if_supported(replay_buffer, "replay_buffer_learner.npz")
+                    # TODO: save other parts of training state
+                else:
+                    print("Training state not saved")
+                print("Stopping learner client")
+                break
+
+    # Wrap up the learner loop
+    server.stop()
+    print("Learner loop finished")
+
 
 ##############################################################################
 