Skip to content

Commit

Permalink
move jax.block_until_ready to accelerate training (#30)
Browse files Browse the repository at this point in the history
  • Loading branch information
ColinQiyangLi authored Feb 26, 2024
1 parent 4c1c0cf commit 3f48278
Show file tree
Hide file tree
Showing 6 changed files with 6 additions and 12 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -377,17 +377,16 @@ def stats_callback(type: str, payload: dict) -> dict:
agent, critics_info = agent.update_critics(
batch,
)
agent = jax.block_until_ready(agent)

with timer.context("train"):
batch = next(replay_iterator)
demo_batch = next(demo_iterator)
batch = concat_batches(batch, demo_batch, axis=0)
agent, update_info = agent.update_high_utd(batch, utd_ratio=1)
agent = jax.block_until_ready(agent)

# publish the updated network
if update_steps > 0 and update_steps % (FLAGS.steps_per_update) == 0:
agent = jax.block_until_ready(agent)
server.publish_network(agent.state.params)

if update_steps % FLAGS.log_period == 0 and wandb_logger:
Expand Down
3 changes: 1 addition & 2 deletions examples/async_cable_route_drq/async_drq_randomized.py
Original file line number Diff line number Diff line change
Expand Up @@ -286,17 +286,16 @@ def stats_callback(type: str, payload: dict) -> dict:
agent, critics_info = agent.update_critics(
batch,
)
agent = jax.block_until_ready(agent)

with timer.context("train"):
batch = next(replay_iterator)
demo_batch = next(demo_iterator)
batch = concat_batches(batch, demo_batch, axis=0)
agent, update_info = agent.update_high_utd(batch, utd_ratio=1)
agent = jax.block_until_ready(agent)

# publish the updated network
if step > 0 and step % (FLAGS.steps_per_update) == 0:
agent = jax.block_until_ready(agent)
server.publish_network(agent.state.params)

if update_steps % FLAGS.log_period == 0 and wandb_logger:
Expand Down
3 changes: 1 addition & 2 deletions examples/async_drq_sim/async_drq_sim.py
Original file line number Diff line number Diff line change
Expand Up @@ -229,15 +229,14 @@ def stats_callback(type: str, payload: dict) -> dict:
agent, critics_info = agent.update_critics(
batch,
)
agent = jax.block_until_ready(agent)

with timer.context("train"):
batch = next(replay_iterator)
agent, update_info = agent.update_high_utd(batch, utd_ratio=1)
agent = jax.block_until_ready(agent)

# publish the updated network
if step > 0 and step % (FLAGS.steps_per_update) == 0:
agent = jax.block_until_ready(agent)
server.publish_network(agent.state.params)

if update_steps % FLAGS.log_period == 0 and wandb_logger:
Expand Down
3 changes: 1 addition & 2 deletions examples/async_pcb_insert_drq/async_drq_randomized.py
Original file line number Diff line number Diff line change
Expand Up @@ -281,17 +281,16 @@ def stats_callback(type: str, payload: dict) -> dict:
agent, critics_info = agent.update_critics(
batch,
)
agent = jax.block_until_ready(agent)

with timer.context("train"):
batch = next(replay_iterator)
demo_batch = next(demo_iterator)
batch = concat_batches(batch, demo_batch, axis=0)
agent, update_info = agent.update_high_utd(batch, utd_ratio=1)
agent = jax.block_until_ready(agent)

# publish the updated network
if step > 0 and step % (FLAGS.steps_per_update) == 0:
agent = jax.block_until_ready(agent)
server.publish_network(agent.state.params)

if update_steps % FLAGS.log_period == 0 and wandb_logger:
Expand Down
3 changes: 1 addition & 2 deletions examples/async_peg_insert_drq/async_drq_randomized.py
Original file line number Diff line number Diff line change
Expand Up @@ -280,17 +280,16 @@ def stats_callback(type: str, payload: dict) -> dict:
agent, critics_info = agent.update_critics(
batch,
)
agent = jax.block_until_ready(agent)

with timer.context("train"):
batch = next(replay_iterator)
demo_batch = next(demo_iterator)
batch = concat_batches(batch, demo_batch, axis=0)
agent, update_info = agent.update_high_utd(batch, utd_ratio=1)
agent = jax.block_until_ready(agent)

# publish the updated network
if step > 0 and step % (FLAGS.steps_per_update) == 0:
agent = jax.block_until_ready(agent)
server.publish_network(agent.state.params)

if update_steps % FLAGS.log_period == 0 and wandb_logger:
Expand Down
3 changes: 1 addition & 2 deletions examples/async_rlpd_drq_sim/async_rlpd_drq_sim.py
Original file line number Diff line number Diff line change
Expand Up @@ -247,18 +247,17 @@ def stats_callback(type: str, payload: dict) -> dict:
agent, critics_info = agent.update_critics(
batch,
)
agent = jax.block_until_ready(agent)

with timer.context("train"):
batch = next(replay_iterator)
if demo_iterator is not None:
demo_batch = next(demo_iterator)
batch = concat_batches(batch, demo_batch, axis=0)
agent, update_info = agent.update_high_utd(batch, utd_ratio=1)
agent = jax.block_until_ready(agent)

# publish the updated network
if step > 0 and step % (FLAGS.steps_per_update) == 0:
agent = jax.block_until_ready(agent)
server.publish_network(agent.state.params)

if update_steps % FLAGS.log_period == 0 and wandb_logger:
Expand Down

0 comments on commit 3f48278

Please sign in to comment.