diff --git a/examples/async_bin_relocation_fwbw_drq/async_drq_randomized.py b/examples/async_bin_relocation_fwbw_drq/async_drq_randomized.py index 9aa48b07..431bad3d 100644 --- a/examples/async_bin_relocation_fwbw_drq/async_drq_randomized.py +++ b/examples/async_bin_relocation_fwbw_drq/async_drq_randomized.py @@ -377,17 +377,16 @@ def stats_callback(type: str, payload: dict) -> dict: agent, critics_info = agent.update_critics( batch, ) - agent = jax.block_until_ready(agent) with timer.context("train"): batch = next(replay_iterator) demo_batch = next(demo_iterator) batch = concat_batches(batch, demo_batch, axis=0) agent, update_info = agent.update_high_utd(batch, utd_ratio=1) - agent = jax.block_until_ready(agent) # publish the updated network if update_steps > 0 and update_steps % (FLAGS.steps_per_update) == 0: + agent = jax.block_until_ready(agent) server.publish_network(agent.state.params) if update_steps % FLAGS.log_period == 0 and wandb_logger: diff --git a/examples/async_cable_route_drq/async_drq_randomized.py b/examples/async_cable_route_drq/async_drq_randomized.py index f53ceb36..ebbf7a48 100644 --- a/examples/async_cable_route_drq/async_drq_randomized.py +++ b/examples/async_cable_route_drq/async_drq_randomized.py @@ -286,17 +286,16 @@ def stats_callback(type: str, payload: dict) -> dict: agent, critics_info = agent.update_critics( batch, ) - agent = jax.block_until_ready(agent) with timer.context("train"): batch = next(replay_iterator) demo_batch = next(demo_iterator) batch = concat_batches(batch, demo_batch, axis=0) agent, update_info = agent.update_high_utd(batch, utd_ratio=1) - agent = jax.block_until_ready(agent) # publish the updated network if step > 0 and step % (FLAGS.steps_per_update) == 0: + agent = jax.block_until_ready(agent) server.publish_network(agent.state.params) if update_steps % FLAGS.log_period == 0 and wandb_logger: diff --git a/examples/async_drq_sim/async_drq_sim.py b/examples/async_drq_sim/async_drq_sim.py index 59a4de71..d245ea66 100644 --- a/examples/async_drq_sim/async_drq_sim.py +++ b/examples/async_drq_sim/async_drq_sim.py @@ -229,15 +229,14 @@ def stats_callback(type: str, payload: dict) -> dict: agent, critics_info = agent.update_critics( batch, ) - agent = jax.block_until_ready(agent) with timer.context("train"): batch = next(replay_iterator) agent, update_info = agent.update_high_utd(batch, utd_ratio=1) - agent = jax.block_until_ready(agent) # publish the updated network if step > 0 and step % (FLAGS.steps_per_update) == 0: + agent = jax.block_until_ready(agent) server.publish_network(agent.state.params) if update_steps % FLAGS.log_period == 0 and wandb_logger: diff --git a/examples/async_pcb_insert_drq/async_drq_randomized.py b/examples/async_pcb_insert_drq/async_drq_randomized.py index 2798b6ff..bdc5dac6 100644 --- a/examples/async_pcb_insert_drq/async_drq_randomized.py +++ b/examples/async_pcb_insert_drq/async_drq_randomized.py @@ -281,17 +281,16 @@ def stats_callback(type: str, payload: dict) -> dict: agent, critics_info = agent.update_critics( batch, ) - agent = jax.block_until_ready(agent) with timer.context("train"): batch = next(replay_iterator) demo_batch = next(demo_iterator) batch = concat_batches(batch, demo_batch, axis=0) agent, update_info = agent.update_high_utd(batch, utd_ratio=1) - agent = jax.block_until_ready(agent) # publish the updated network if step > 0 and step % (FLAGS.steps_per_update) == 0: + agent = jax.block_until_ready(agent) server.publish_network(agent.state.params) if update_steps % FLAGS.log_period == 0 and wandb_logger: diff --git a/examples/async_peg_insert_drq/async_drq_randomized.py b/examples/async_peg_insert_drq/async_drq_randomized.py index 20ac5056..58c54412 100644 --- a/examples/async_peg_insert_drq/async_drq_randomized.py +++ b/examples/async_peg_insert_drq/async_drq_randomized.py @@ -280,17 +280,16 @@ def stats_callback(type: str, payload: dict) -> dict: agent, critics_info = agent.update_critics( batch, ) - agent = jax.block_until_ready(agent) with timer.context("train"): batch = next(replay_iterator) demo_batch = next(demo_iterator) batch = concat_batches(batch, demo_batch, axis=0) agent, update_info = agent.update_high_utd(batch, utd_ratio=1) - agent = jax.block_until_ready(agent) # publish the updated network if step > 0 and step % (FLAGS.steps_per_update) == 0: + agent = jax.block_until_ready(agent) server.publish_network(agent.state.params) if update_steps % FLAGS.log_period == 0 and wandb_logger: diff --git a/examples/async_rlpd_drq_sim/async_rlpd_drq_sim.py b/examples/async_rlpd_drq_sim/async_rlpd_drq_sim.py index 66d06383..81da41a0 100644 --- a/examples/async_rlpd_drq_sim/async_rlpd_drq_sim.py +++ b/examples/async_rlpd_drq_sim/async_rlpd_drq_sim.py @@ -247,7 +247,6 @@ def stats_callback(type: str, payload: dict) -> dict: agent, critics_info = agent.update_critics( batch, ) - agent = jax.block_until_ready(agent) with timer.context("train"): batch = next(replay_iterator) @@ -255,10 +254,10 @@ def stats_callback(type: str, payload: dict) -> dict: demo_batch = next(demo_iterator) batch = concat_batches(batch, demo_batch, axis=0) agent, update_info = agent.update_high_utd(batch, utd_ratio=1) - agent = jax.block_until_ready(agent) # publish the updated network if step > 0 and step % (FLAGS.steps_per_update) == 0: + agent = jax.block_until_ready(agent) server.publish_network(agent.state.params) if update_steps % FLAGS.log_period == 0 and wandb_logger: