What I want is for each matrix to be sharded 'vertically' if its access index is even and 'horizontally' if its access index is odd (an approach inspired by https://jax.readthedocs.io/en/latest/notebooks/Distributed_arrays_and_automatic_parallelization.html). It seems more appropriate to perform this sharding according to how the weights are accessed, rather than according to how they are created, in order to minimize I/O overhead and the associated idleness in the forward pass.

However, the code below does not work! After running it, everything is still placed on the first device (TPU 0):
```python
shard_getter = ShardGetter()

@hk.without_apply_rng
@hk.transform
def objective(inputs: Batch):
    # ... do some application-specific, error-computing stuff...
    with hk.custom_getter(shard_getter):
        loss = model(inputs)
    return loss

@jax.jit
def train_init(rng, inputs):
    params = objective.init(rng, inputs)
    opt_st = optimizer().init(params)
    loss = 0.0
    step = 0
    return TrainState(params, opt_st, loss, step)

inputs = next(batches)
inputs = jax.device_put(inputs, shards.replicate(-1))  # meticulously arrange everything _just so..._
tstate = train_init(jax.device_put(jax.random.PRNGKey(42), shards.replicate()), inputs)
jax.debug.visualize_array_sharding(tstate.params)  # should be fully sharded -- somehow not?
```
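For context, `ShardGetter` is the Haiku custom getter that is supposed to attach those constraints as parameters are accessed. A simplified sketch of what such a getter can look like (illustrative only, not my exact class; the mesh, the `"devices"` axis name, the rank-2 check, and the use of `with_sharding_constraint` are all assumptions here):

```python
import jax
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# Illustrative mesh over all available devices, as in the linked notebook.
mesh = Mesh(jax.devices(), axis_names=("devices",))

class ShardGetter:
    """Alternate the sharded axis based on the order in which parameters are accessed."""

    def __init__(self):
        self._access_index = 0

    # Haiku custom getter signature: (next_getter, value, context).
    # An instance of this class is what gets passed to hk.custom_getter above.
    def __call__(self, next_getter, value, context):
        param = next_getter(value)
        if param.ndim == 2:
            # Even access index: split the first axis ('vertical');
            # odd access index: split the second axis ('horizontal').
            spec = P("devices", None) if self._access_index % 2 == 0 else P(None, "devices")
            param = jax.lax.with_sharding_constraint(param, NamedSharding(mesh, spec))
        self._access_index += 1
        return param
```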
What I would like is for this code to apply the sharding constraints specified in ShardGetter.__call__! As it stands, my monkey patch for this limitation works fine, so I understand this is mainly an aesthetic concern (correcting it only required adding three lines). Still, I am mystified as to what seems to be erasing the sharding constraints after the getter interceptor is applied.
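For concreteness, the kind of three-line patch I mean is simply re-constraining the parameters right after `objective.init` inside `train_init`. A rough sketch (illustrative only; `param_shardings` is a hypothetical pytree of `jax.sharding.NamedSharding` objects matching `params`, not something my actual code defines):

```python
@jax.jit
def train_init(rng, inputs):
    params = objective.init(rng, inputs)
    # Pin each freshly initialized parameter to its intended sharding.
    # `param_shardings` is a hypothetical pytree of NamedShardings mirroring `params`.
    params = jax.tree_util.tree_map(
        jax.lax.with_sharding_constraint, params, param_shardings)
    opt_st = optimizer().init(params)
    return TrainState(params, opt_st, 0.0, 0)
```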