Skip to content

Commit

Permalink
Code update
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 662992054
  • Loading branch information
pschuh authored and The swirl_dynamics Authors committed Aug 14, 2024
1 parent f6884a4 commit 456cab2
Showing 1 changed file with 14 additions and 4 deletions.
18 changes: 14 additions & 4 deletions swirl_dynamics/templates/callbacks.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,17 +128,27 @@ def on_train_begin(self, trainer: Trainer) -> None:
self.last_eval_metric = {}
# retrieve from existing checkpoints if possible
if self.ckpt_manager.latest_step() is not None:

def to_shard_shape_dtype(x):
aval = jax.api_util.shaped_abstractify(x)
if trainer.is_distributed:
return jax.ShapeDtypeStruct(aval.shape[1:], dtype=aval.dtype)
else:
return jax.ShapeDtypeStruct(aval.shape, dtype=aval.dtype)

# Load a single shard and then replicate explicitly.
restored = self.ckpt_manager.restore(
self.ckpt_manager.latest_step(),
args=ocp.args.Composite(**{
self.train_state_field: ocp.args.StandardRestore(
item=trainer.train_state
item=jax.tree.map(to_shard_shape_dtype, trainer.train_state)
)
}),
)
# The restored train state gets automatically replicated according to the
# shape of `trainer.train_state` so we don't need to do it manually.
trainer.train_state = getattr(restored, self.train_state_field)

trainer.train_state = trainer._maybe_replicate( # pylint: disable=protected-access
getattr(restored, self.train_state_field)
)

def on_train_batches_end(
self, trainer: Trainer, train_metrics: ComputedMetrics
Expand Down

0 comments on commit 456cab2

Please sign in to comment.