Skip to content

Commit

Permalink
Add pause/resume to PCB insertion training (#32)
Browse files Browse the repository at this point in the history
* add pause/resume to PCB insertion training

* fix precommit failures

* add mutex for setting SHOULD_PAUSE

* use threading.Event to pause training
  • Loading branch information
gautams3 authored Mar 22, 2024
1 parent d0bd00a commit 3322c33
Showing 1 changed file with 90 additions and 0 deletions.
90 changes: 90 additions & 0 deletions examples/async_pcb_insert_drq/async_drq_randomized.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,12 @@
import jax
import jax.numpy as jnp
import numpy as np
import pynput
import os
import tqdm
from absl import app, flags
from flax.training import checkpoints
import threading

import gymnasium as gym
from gymnasium.wrappers.record_episode_statistics import RecordEpisodeStatistics
Expand Down Expand Up @@ -88,13 +91,38 @@ def print_green(x):
return print("\033[92m {}\033[00m".format(x))


PAUSE_EVENT_FLAG = threading.Event()
PAUSE_EVENT_FLAG.clear() # clear() to continue the actor/learner loop, set() to pause


def pause_callback(key):
"""Callback for when a key is pressed"""
global PAUSE_EVENT_FLAG
try:
# chosen a rarely used key to avoid conflicts. this listener is always on, even when the program is not in focus
if key == pynput.keyboard.Key.pause:
print("Requested pause training")
# set the PAUSE FLAG to pause the actor/learner loop
PAUSE_EVENT_FLAG.set()
except AttributeError:
# print(f'{key} pressed')
pass


listener = pynput.keyboard.Listener(
on_press=pause_callback
) # to enable keyboard based pause
listener.start()

##############################################################################


def actor(agent: DrQAgent, data_store, env, sampling_rng):
"""
This is the actor loop, which runs when "--actor" is set to True.
"""
global PAUSE_EVENT_FLAG

if FLAGS.eval_checkpoint_step:
success_counter = 0
time_list = []
Expand Down Expand Up @@ -130,6 +158,18 @@ def actor(agent: DrQAgent, data_store, env, sampling_rng):
print(reward)
print(f"{success_counter}/{episode + 1}")

# if pause event is requested, pause the actor
if PAUSE_EVENT_FLAG.is_set():
print("Actor eval loop interrupted")
response = input("Do you want to continue (c), or exit (e)? ")
if response == "c":
# update PAUSE FLAG to continue training
PAUSE_EVENT_FLAG.clear()
print("Continuing")
else:
print("Stopping actor eval")
break

print(f"success rate: {success_counter / FLAGS.eval_n_trajs}")
print(f"average time: {np.mean(time_list)}")
return # after done eval, return and exit
Expand Down Expand Up @@ -209,6 +249,28 @@ def update_params(params):
stats = {"timer": timer.get_average_times()}
client.request("send-stats", stats)

if PAUSE_EVENT_FLAG.is_set():
print_green("Actor loop interrupted")
response = input(
"Do you want to continue (c), save replay buffer and exit (s) or simply exit (e)? "
)
if response == "c":
print("Continuing")
PAUSE_EVENT_FLAG.clear()
else:
if response == "s":
print("Saving replay buffer")
data_store.save(
"replay_buffer_actor.npz"
) # not yet supported for QueuedDataStore
else:
print("Replay buffer not saved")
print("Stopping actor client")
client.stop()
break

print("Actor loop finished")


##############################################################################

Expand All @@ -219,6 +281,7 @@ def learner(rng, agent: DrQAgent, replay_buffer, demo_buffer, wandb_logger=None)
"""
# To track the step in the training loop
update_steps = 0
global PAUSE_EVENT_FLAG

def stats_callback(type: str, payload: dict) -> dict:
"""Callback for when server receives stats request."""
Expand Down Expand Up @@ -305,6 +368,33 @@ def stats_callback(type: str, payload: dict) -> dict:

update_steps += 1

if PAUSE_EVENT_FLAG.is_set():
print("Learner loop interrupted")
response = input(
"Do you want to continue (c), save training state and exit (s) or simply exit (e)? "
)
if response == "c":
print("Continuing")
PAUSE_EVENT_FLAG.clear()
else:
if response == "s":
print("Saving learner state")
agent_ckpt = checkpoints.save_checkpoint(
FLAGS.checkpoint_path, agent.state, step=update_steps, keep=100
)
replay_buffer.save(
"replay_buffer_learner.npz"
) # not yet supported for QueuedDataStore
# TODO: save other parts of training state
else:
print("Training state not saved")
print("Stopping learner client")
break

# Wrap up the learner loop
server.stop()
print("Learner loop finished")


##############################################################################

Expand Down

0 comments on commit 3322c33

Please sign in to comment.