From 89fec8fd0ef67be39cd9b9f92e3c3e8be519f32c Mon Sep 17 00:00:00 2001 From: rka97 Date: Thu, 21 Nov 2024 11:56:05 -0500 Subject: [PATCH 1/8] Use jax.jit for sharding initial steps Apply it to the MNIST workload and the Nesterov optimizer. --- algorithmic_efficiency/checkpoint_utils.py | 3 - algorithmic_efficiency/data_utils.py | 5 +- algorithmic_efficiency/sharding_utils.py | 62 ++++++++++++++++++ .../workloads/mnist/mnist_jax/workload.py | 19 +++--- .../nesterov/jax/submission.py | 63 ++++++++++++++----- 5 files changed, 121 insertions(+), 31 deletions(-) create mode 100644 algorithmic_efficiency/sharding_utils.py diff --git a/algorithmic_efficiency/checkpoint_utils.py b/algorithmic_efficiency/checkpoint_utils.py index 29c1a821e..c98799ca7 100644 --- a/algorithmic_efficiency/checkpoint_utils.py +++ b/algorithmic_efficiency/checkpoint_utils.py @@ -193,10 +193,7 @@ def save_checkpoint(framework: str, train_state, eval_results, global_step, preemption_count). """ if framework == 'jax': - model_params = jax.device_get(jax_utils.unreplicate(model_params)) opt_state, _ = optimizer_state - opt_state = jax.device_get(jax_utils.unreplicate(opt_state)) - model_state = jax.device_get(jax_utils.unreplicate(model_state)) else: if isinstance( model_params, diff --git a/algorithmic_efficiency/data_utils.py b/algorithmic_efficiency/data_utils.py index 901f0b582..557b4a68d 100644 --- a/algorithmic_efficiency/data_utils.py +++ b/algorithmic_efficiency/data_utils.py @@ -60,10 +60,7 @@ def _prepare(x): if remainder_size != 0 or pad_to_global_batch_size: x = pad(x, pad_size, padding_value=padding_value) - # Reshape (global_batch_size, ...) to - # (local_device_count, per_device_batch_size, ...). - # Assumes that `global_batch_size % local_device_count == 0`. - return x.reshape((local_device_count, -1, *x.shape[1:])) + return x return jax.tree_map(_prepare, batch) diff --git a/algorithmic_efficiency/sharding_utils.py b/algorithmic_efficiency/sharding_utils.py new file mode 100644 index 000000000..4950f243a --- /dev/null +++ b/algorithmic_efficiency/sharding_utils.py @@ -0,0 +1,62 @@ +"""Utilities for dealing with sharding in JAX.""" + +import jax +from jax.sharding import Mesh, NamedSharding, PartitionSpec + + +def get_mesh() -> jax.sharding.Mesh: + """Creates a mesh from all available GPUs. 
Here, we simply create a one-dimensional mesh.""" + return jax.sharding.Mesh(jax.devices(), ("batch",)) + +def get_replicated_sharding(mesh=None): + """Returns a sharding spec that replicates data across all devices.""" + if mesh is None: + mesh = get_mesh() + return NamedSharding(mesh, PartitionSpec()) + +def get_naive_sharding_spec(mesh=None): + """Returns a sharding spec that shards data along the first axis.""" + if mesh is None: + mesh = get_mesh() + return NamedSharding(mesh, PartitionSpec("batch")) + + +def get_naive_sharding(x, mesh=None): + """Given a 1D mesh and a tensor, try to shard along the appropriate axis.""" + if mesh is None: + mesh = get_mesh() + grid_size = mesh.shape["batch"] + if x.shape[0] % grid_size == 0: + return NamedSharding(mesh, PartitionSpec("batch")) + else: + return NamedSharding(mesh, PartitionSpec()) + + +def shard_params(params, mesh=None): + """Shards a parameter tree across all devices with naive sharding (see get_naive_sharding).""" + if mesh is None: + mesh = get_mesh() + return jax.tree_util.tree_map( + lambda x: jax.device_put(x, get_naive_sharding(x)), params + ) + + +def get_sharding_tree(params, mesh=None): + """Returns a sharding tree for a parameter tree.""" + return jax.tree_util.tree_map(lambda x: get_naive_sharding(x, mesh), params) + + +def get_empty_sharding(mesh=None): + """Returns a sharding spec that replicates data across all devices.""" + if mesh is None: + mesh = get_mesh() + return NamedSharding(mesh, PartitionSpec()) + + +def disp_shard_info(x: jax.Array): + """Displays shard info of a jax array.""" + for shard in x.addressable_shards: + print( + f"shard.device: {shard.device}, index: {shard.index}, replica_id:" + f" {shard.replica_id}.\n" + ) diff --git a/algorithmic_efficiency/workloads/mnist/mnist_jax/workload.py b/algorithmic_efficiency/workloads/mnist/mnist_jax/workload.py index efbd73e33..11429bae6 100644 --- a/algorithmic_efficiency/workloads/mnist/mnist_jax/workload.py +++ b/algorithmic_efficiency/workloads/mnist/mnist_jax/workload.py @@ -10,7 +10,7 @@ import jax.numpy as jnp import optax -from algorithmic_efficiency import param_utils +from algorithmic_efficiency import param_utils, sharding_utils from algorithmic_efficiency import spec from algorithmic_efficiency.workloads.mnist.workload import BaseMnistWorkload @@ -46,7 +46,7 @@ def init_model_fn( train=True)['params'] self._param_shapes = param_utils.jax_param_shapes(initial_params) self._param_types = param_utils.jax_param_types(self._param_shapes) - return jax_utils.replicate(initial_params), None + return initial_params, None def is_output_params(self, param_key: spec.ParameterKey) -> bool: return param_key == 'Dense_1' @@ -101,10 +101,13 @@ def loss_fn( } @functools.partial( - jax.pmap, - axis_name='batch', - in_axes=(None, 0, 0, 0, None), - static_broadcasted_argnums=(0,)) + jax.jit, + in_shardings=(sharding_utils.get_replicated_sharding(), # params + sharding_utils.get_naive_sharding_spec(), # batch + sharding_utils.get_replicated_sharding(), # model_state + sharding_utils.get_naive_sharding_spec(), # rng + ), + static_argnums=(0,)) def _eval_model( self, params: spec.ParameterContainer, @@ -125,11 +128,11 @@ def _eval_model( (jnp.argmax(logits, axis=-1) == batch['targets']) * weights) summed_loss = self.loss_fn(batch['targets'], logits, weights)['summed'] metrics = {'accuracy': accuracy, 'loss': summed_loss} - metrics = lax.psum(metrics, axis_name='batch') return metrics def _normalize_eval_metrics( self, num_examples: int, total_metrics: Dict[str, Any]) -> 
Dict[str, float]: """Normalize eval metrics.""" - return jax.tree_map(lambda x: float(x[0] / num_examples), total_metrics) + total_metrics = {'accuracy': total_metrics['accuracy'].item() / num_examples, 'loss': total_metrics['loss'].item() / num_examples} + return total_metrics diff --git a/reference_algorithms/paper_baselines/nesterov/jax/submission.py b/reference_algorithms/paper_baselines/nesterov/jax/submission.py index f3b0aeed4..817138c6b 100644 --- a/reference_algorithms/paper_baselines/nesterov/jax/submission.py +++ b/reference_algorithms/paper_baselines/nesterov/jax/submission.py @@ -6,10 +6,11 @@ from flax import jax_utils import jax from jax import lax +from jax.sharding import NamedSharding, PartitionSpec as P import jax.numpy as jnp import optax -from algorithmic_efficiency import spec +from algorithmic_efficiency import spec, sharding_utils _GRAD_CLIP_EPS = 1e-6 @@ -37,7 +38,7 @@ def init_optimizer_state(workload: spec.Workload, nesterov=True) optimizer_state = opt_init_fn(params_zeros_like) - return jax_utils.replicate(optimizer_state), opt_update_fn + return optimizer_state, opt_update_fn def create_lr_schedule_fn( @@ -87,13 +88,13 @@ def sgd(learning_rate, weight_decay, momentum=None, nesterov=False): learning_rate=learning_rate, momentum=momentum, nesterov=nesterov)) -@functools.partial( - jax.pmap, - axis_name='batch', - in_axes=(None, None, 0, 0, 0, 0, 0, None, None), - static_broadcasted_argnums=(0, 1), - donate_argnums=(2, 3, 4)) -def pmapped_train_step(workload, +# @functools.partial( +# jax.pmap, +# axis_name='batch', +# in_axes=(None, None, 0, 0, 0, 0, 0, None, None), +# static_broadcasted_argnums=(0, 1), +# donate_argnums=(2, 3, 4)) +def train_step(workload, opt_update_fn, model_state, optimizer_state, @@ -124,12 +125,9 @@ def _loss_fn(params): grad_fn = jax.value_and_grad(_loss_fn, has_aux=True) (summed_loss, (n_valid_examples, new_model_state)), grad = grad_fn( current_param_container) - # Get correct global mean loss and grad. - (summed_loss, n_valid_examples, grad) = lax.psum( - (summed_loss, n_valid_examples, grad), axis_name='batch') + # # Get correct global mean loss and grad. 
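+  # Under jax.jit with a batch-sharded input and replicated params,
+  # summed_loss, n_valid_examples, and grad already carry global semantics;
+  # XLA inserts the cross-device reductions, so the explicit lax.psum above
+  # is no longer needed.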
loss = summed_loss / n_valid_examples grad = jax.tree_map(lambda x: x / n_valid_examples, grad) - grad_norm = jnp.sqrt( sum(jnp.sum(g**2) for g in jax.tree_util.tree_leaves(grad))) @@ -170,7 +168,40 @@ def update_params(workload: spec.Workload, grad_clip = hyperparameters.grad_clip else: grad_clip = None - outputs = pmapped_train_step(workload, + + mesh = sharding_utils.get_mesh() + # Create shardings for each argument + replicated = NamedSharding(mesh, P()) # No partitioning + sharded = NamedSharding(mesh, P('batch')) # Partition along batch dimension + + # Create the sharding rules for each argument + arg_shardings = ( + # workload is static + # opt_update_fn is static + replicated, # model_state + replicated, # optimizer_state + replicated, # current_param_container + sharded, # batch + sharded, # per_device_rngs + replicated, # grad_clip + replicated # label_smoothing + ) + out_shardings = ( + replicated, # new_optimizer_state + replicated, # updated_params + replicated, # new_model_state + replicated, # loss + replicated # grad_norm + ) + # Jit with shardings + jitted_train_step = jax.jit( + train_step, + static_argnums=(0, 1), + donate_argnums=(2, 3, 4), + in_shardings=arg_shardings, + out_shardings=out_shardings + ) + outputs = jitted_train_step(workload, opt_update_fn, model_state, optimizer_state, @@ -185,8 +216,8 @@ def update_params(workload: spec.Workload, if global_step % 100 == 0 and workload.metrics_logger is not None: workload.metrics_logger.append_scalar_metrics( { - 'loss': loss[0], - 'grad_norm': grad_norm[0], + 'loss': loss.item(), + 'grad_norm': grad_norm.item(), }, global_step) return (new_optimizer_state, opt_update_fn), new_params, new_model_state From c53729c53160754bb3dd60d5641d5bdaaa3832c4 Mon Sep 17 00:00:00 2001 From: rka97 Date: Thu, 21 Nov 2024 12:14:09 -0500 Subject: [PATCH 2/8] Use jax.jit for adamw --- .../paper_baselines/adamw/jax/submission.py | 97 ++++++++++++------- 1 file changed, 63 insertions(+), 34 deletions(-) diff --git a/reference_algorithms/paper_baselines/adamw/jax/submission.py b/reference_algorithms/paper_baselines/adamw/jax/submission.py index 80a963600..71a677b2f 100644 --- a/reference_algorithms/paper_baselines/adamw/jax/submission.py +++ b/reference_algorithms/paper_baselines/adamw/jax/submission.py @@ -6,10 +6,11 @@ from flax import jax_utils import jax from jax import lax +from jax.sharding import NamedSharding, PartitionSpec as P import jax.numpy as jnp import optax -from algorithmic_efficiency import spec +from algorithmic_efficiency import spec, sharding_utils _GRAD_CLIP_EPS = 1e-6 @@ -50,24 +51,18 @@ def jax_cosine_warmup(step_hint: int, hyperparameters): workload.param_shapes) optimizer_state = opt_init_fn(params_zeros_like) - return jax_utils.replicate(optimizer_state), opt_update_fn - - -@functools.partial( - jax.pmap, - axis_name='batch', - in_axes=(None, None, 0, 0, 0, 0, 0, None, None), - static_broadcasted_argnums=(0, 1), - donate_argnums=(2, 3, 4)) -def pmapped_train_step(workload, - opt_update_fn, - model_state, - optimizer_state, - current_param_container, - batch, - rng, - grad_clip, - label_smoothing): + return optimizer_state, opt_update_fn + + +def train_step(workload, + opt_update_fn, + model_state, + optimizer_state, + current_param_container, + batch, + rng, + grad_clip, + label_smoothing): def _loss_fn(params): """Loss function used for training.""" @@ -90,9 +85,8 @@ def _loss_fn(params): grad_fn = jax.value_and_grad(_loss_fn, has_aux=True) (summed_loss, (n_valid_examples, new_model_state)), grad = grad_fn( 
current_param_container) - # Get correct global mean loss and grad. - (summed_loss, n_valid_examples, grad) = lax.psum( - (summed_loss, n_valid_examples, grad), axis_name='batch') + + # Compute local loss and gradients loss = summed_loss / n_valid_examples grad = jax.tree_map(lambda x: x / n_valid_examples, grad) @@ -105,7 +99,7 @@ def _loss_fn(params): grad = jax.tree_map(lambda x: x * grad_scaling_factor, grad) updates, new_optimizer_state = opt_update_fn(grad, optimizer_state, - current_param_container) + current_param_container) updated_params = optax.apply_updates(current_param_container, updates) return new_optimizer_state, updated_params, new_model_state, loss, grad_norm @@ -136,23 +130,58 @@ def update_params(workload: spec.Workload, grad_clip = hyperparameters.grad_clip else: grad_clip = None - outputs = pmapped_train_step(workload, - opt_update_fn, - model_state, - optimizer_state, - current_param_container, - batch, - per_device_rngs, - grad_clip, - label_smoothing) + + # Set up mesh and sharding + mesh = sharding_utils.get_mesh() + replicated = NamedSharding(mesh, P()) # No partitioning + sharded = NamedSharding(mesh, P('batch')) # Partition along batch dimension + + # Define input and output shardings + arg_shardings = ( + # workload is static + # opt_update_fn is static + replicated, # model_state + replicated, # optimizer_state + replicated, # current_param_container + sharded, # batch + sharded, # rng + replicated, # grad_clip + replicated # label_smoothing + ) + out_shardings = ( + replicated, # new_optimizer_state + replicated, # updated_params + replicated, # new_model_state + replicated, # loss + replicated # grad_norm + ) + + # Jit with shardings + jitted_train_step = jax.jit( + train_step, + static_argnums=(0, 1), + donate_argnums=(2, 3, 4), + in_shardings=arg_shardings, + out_shardings=out_shardings + ) + + outputs = jitted_train_step(workload, + opt_update_fn, + model_state, + optimizer_state, + current_param_container, + batch, + per_device_rngs, + grad_clip, + label_smoothing) new_optimizer_state, new_params, new_model_state, loss, grad_norm = outputs # Log loss, grad_norm. if global_step % 100 == 0 and workload.metrics_logger is not None: workload.metrics_logger.append_scalar_metrics( { - 'loss': loss[0], - 'grad_norm': grad_norm[0], + 'loss': loss.item(), + 'grad_norm': grad_norm.item(), }, global_step) return (new_optimizer_state, opt_update_fn), new_params, new_model_state From ef6af03a1818e0d3c863a391222c6a44ad0b758e Mon Sep 17 00:00:00 2001 From: rka97 Date: Mon, 9 Dec 2024 01:53:50 -0500 Subject: [PATCH 3/8] Pass yapf checks --- algorithmic_efficiency/sharding_utils.py | 71 +++++++++---------- .../workloads/mnist/mnist_jax/workload.py | 16 +++-- .../workloads/mnist/workload.py | 7 +- .../paper_baselines/adamw/jax/submission.py | 19 +++-- .../nesterov/jax/submission.py | 59 +++++++-------- submission_runner.py | 4 +- 6 files changed, 88 insertions(+), 88 deletions(-) diff --git a/algorithmic_efficiency/sharding_utils.py b/algorithmic_efficiency/sharding_utils.py index 4950f243a..62a441bc9 100644 --- a/algorithmic_efficiency/sharding_utils.py +++ b/algorithmic_efficiency/sharding_utils.py @@ -5,58 +5,57 @@ def get_mesh() -> jax.sharding.Mesh: - """Creates a mesh from all available GPUs. Here, we simply create a one-dimensional mesh.""" - return jax.sharding.Mesh(jax.devices(), ("batch",)) + """Creates a mesh from all available GPUs. 
Here, we simply create a one-dimensional mesh.""" + return jax.sharding.Mesh(jax.devices(), ("batch",)) + def get_replicated_sharding(mesh=None): - """Returns a sharding spec that replicates data across all devices.""" - if mesh is None: - mesh = get_mesh() - return NamedSharding(mesh, PartitionSpec()) + """Returns a sharding spec that replicates data across all devices.""" + if mesh is None: + mesh = get_mesh() + return NamedSharding(mesh, PartitionSpec()) + def get_naive_sharding_spec(mesh=None): - """Returns a sharding spec that shards data along the first axis.""" - if mesh is None: - mesh = get_mesh() - return NamedSharding(mesh, PartitionSpec("batch")) + """Returns a sharding spec that shards data along the first axis.""" + if mesh is None: + mesh = get_mesh() + return NamedSharding(mesh, PartitionSpec("batch")) def get_naive_sharding(x, mesh=None): - """Given a 1D mesh and a tensor, try to shard along the appropriate axis.""" - if mesh is None: - mesh = get_mesh() - grid_size = mesh.shape["batch"] - if x.shape[0] % grid_size == 0: - return NamedSharding(mesh, PartitionSpec("batch")) - else: - return NamedSharding(mesh, PartitionSpec()) + """Given a 1D mesh and a tensor, try to shard along the appropriate axis.""" + if mesh is None: + mesh = get_mesh() + grid_size = mesh.shape["batch"] + if x.shape[0] % grid_size == 0: + return NamedSharding(mesh, PartitionSpec("batch")) + else: + return NamedSharding(mesh, PartitionSpec()) def shard_params(params, mesh=None): - """Shards a parameter tree across all devices with naive sharding (see get_naive_sharding).""" - if mesh is None: - mesh = get_mesh() - return jax.tree_util.tree_map( - lambda x: jax.device_put(x, get_naive_sharding(x)), params - ) + """Shards a parameter tree across all devices with naive sharding (see get_naive_sharding).""" + if mesh is None: + mesh = get_mesh() + return jax.tree_util.tree_map( + lambda x: jax.device_put(x, get_naive_sharding(x)), params) def get_sharding_tree(params, mesh=None): - """Returns a sharding tree for a parameter tree.""" - return jax.tree_util.tree_map(lambda x: get_naive_sharding(x, mesh), params) + """Returns a sharding tree for a parameter tree.""" + return jax.tree_util.tree_map(lambda x: get_naive_sharding(x, mesh), params) def get_empty_sharding(mesh=None): - """Returns a sharding spec that replicates data across all devices.""" - if mesh is None: - mesh = get_mesh() - return NamedSharding(mesh, PartitionSpec()) + """Returns a sharding spec that replicates data across all devices.""" + if mesh is None: + mesh = get_mesh() + return NamedSharding(mesh, PartitionSpec()) def disp_shard_info(x: jax.Array): - """Displays shard info of a jax array.""" - for shard in x.addressable_shards: - print( - f"shard.device: {shard.device}, index: {shard.index}, replica_id:" - f" {shard.replica_id}.\n" - ) + """Displays shard info of a jax array.""" + for shard in x.addressable_shards: + print(f"shard.device: {shard.device}, index: {shard.index}, replica_id:" + f" {shard.replica_id}.\n") diff --git a/algorithmic_efficiency/workloads/mnist/mnist_jax/workload.py b/algorithmic_efficiency/workloads/mnist/mnist_jax/workload.py index 11429bae6..b0a52d77f 100644 --- a/algorithmic_efficiency/workloads/mnist/mnist_jax/workload.py +++ b/algorithmic_efficiency/workloads/mnist/mnist_jax/workload.py @@ -102,11 +102,12 @@ def loss_fn( @functools.partial( jax.jit, - in_shardings=(sharding_utils.get_replicated_sharding(), # params - sharding_utils.get_naive_sharding_spec(), # batch - 
sharding_utils.get_replicated_sharding(), # model_state - sharding_utils.get_naive_sharding_spec(), # rng - ), + in_shardings=( + sharding_utils.get_replicated_sharding(), # params + sharding_utils.get_naive_sharding_spec(), # batch + sharding_utils.get_replicated_sharding(), # model_state + sharding_utils.get_naive_sharding_spec(), # rng + ), static_argnums=(0,)) def _eval_model( self, @@ -134,5 +135,8 @@ def _normalize_eval_metrics( self, num_examples: int, total_metrics: Dict[str, Any]) -> Dict[str, float]: """Normalize eval metrics.""" - total_metrics = {'accuracy': total_metrics['accuracy'].item() / num_examples, 'loss': total_metrics['loss'].item() / num_examples} + total_metrics = { + 'accuracy': total_metrics['accuracy'].item() / num_examples, + 'loss': total_metrics['loss'].item() / num_examples + } return total_metrics diff --git a/algorithmic_efficiency/workloads/mnist/workload.py b/algorithmic_efficiency/workloads/mnist/workload.py index dcc195170..ad950b869 100644 --- a/algorithmic_efficiency/workloads/mnist/workload.py +++ b/algorithmic_efficiency/workloads/mnist/workload.py @@ -46,8 +46,7 @@ def _build_mnist_dataset( ds = ds.map( lambda x: { 'inputs': _normalize(x['image'], train_mean, train_stddev), - 'targets': x['label'], - }) + 'targets': x['label'],}) is_train = split == 'train' if cache: @@ -214,8 +213,6 @@ def _eval_model_on_split(self, batch, model_state, per_device_model_rngs) - total_metrics = { - k: v + batch_metrics[k] for k, v in total_metrics.items() - } + total_metrics = {k: v + batch_metrics[k] for k, v in total_metrics.items()} return self._normalize_eval_metrics(num_examples, total_metrics) diff --git a/reference_algorithms/paper_baselines/adamw/jax/submission.py b/reference_algorithms/paper_baselines/adamw/jax/submission.py index 71a677b2f..40be6dc58 100644 --- a/reference_algorithms/paper_baselines/adamw/jax/submission.py +++ b/reference_algorithms/paper_baselines/adamw/jax/submission.py @@ -162,18 +162,17 @@ def update_params(workload: spec.Workload, static_argnums=(0, 1), donate_argnums=(2, 3, 4), in_shardings=arg_shardings, - out_shardings=out_shardings - ) + out_shardings=out_shardings) outputs = jitted_train_step(workload, - opt_update_fn, - model_state, - optimizer_state, - current_param_container, - batch, - per_device_rngs, - grad_clip, - label_smoothing) + opt_update_fn, + model_state, + optimizer_state, + current_param_container, + batch, + per_device_rngs, + grad_clip, + label_smoothing) new_optimizer_state, new_params, new_model_state, loss, grad_norm = outputs # Log loss, grad_norm. 
diff --git a/reference_algorithms/paper_baselines/nesterov/jax/submission.py b/reference_algorithms/paper_baselines/nesterov/jax/submission.py index 817138c6b..a24e3baab 100644 --- a/reference_algorithms/paper_baselines/nesterov/jax/submission.py +++ b/reference_algorithms/paper_baselines/nesterov/jax/submission.py @@ -95,14 +95,14 @@ def sgd(learning_rate, weight_decay, momentum=None, nesterov=False): # static_broadcasted_argnums=(0, 1), # donate_argnums=(2, 3, 4)) def train_step(workload, - opt_update_fn, - model_state, - optimizer_state, - current_param_container, - batch, - rng, - grad_clip, - label_smoothing): + opt_update_fn, + model_state, + optimizer_state, + current_param_container, + batch, + rng, + grad_clip, + label_smoothing): def _loss_fn(params): """Loss function used for training.""" @@ -178,20 +178,20 @@ def update_params(workload: spec.Workload, arg_shardings = ( # workload is static # opt_update_fn is static - replicated, # model_state - replicated, # optimizer_state - replicated, # current_param_container - sharded, # batch - sharded, # per_device_rngs + replicated, # model_state + replicated, # optimizer_state + replicated, # current_param_container + sharded, # batch + sharded, # per_device_rngs replicated, # grad_clip - replicated # label_smoothing + replicated # label_smoothing ) out_shardings = ( - replicated, # new_optimizer_state - replicated, # updated_params - replicated, # new_model_state - replicated, # loss - replicated # grad_norm + replicated, # new_optimizer_state + replicated, # updated_params + replicated, # new_model_state + replicated, # loss + replicated # grad_norm ) # Jit with shardings jitted_train_step = jax.jit( @@ -199,17 +199,16 @@ def update_params(workload: spec.Workload, static_argnums=(0, 1), donate_argnums=(2, 3, 4), in_shardings=arg_shardings, - out_shardings=out_shardings - ) + out_shardings=out_shardings) outputs = jitted_train_step(workload, - opt_update_fn, - model_state, - optimizer_state, - current_param_container, - batch, - per_device_rngs, - grad_clip, - label_smoothing) + opt_update_fn, + model_state, + optimizer_state, + current_param_container, + batch, + per_device_rngs, + grad_clip, + label_smoothing) new_optimizer_state, new_params, new_model_state, loss, grad_norm = outputs # Log loss, grad_norm. @@ -246,6 +245,8 @@ def get_batch_size(workload_name): return 128 elif workload_name == 'mnist': return 16 + elif workload_name == 'cifar': + return 128 else: raise ValueError(f'Unsupported workload name: {workload_name}.') diff --git a/submission_runner.py b/submission_runner.py index 551173bf5..d40b37bf4 100644 --- a/submission_runner.py +++ b/submission_runner.py @@ -377,8 +377,8 @@ def train_once( train_state['is_time_remaining'] = ( train_state['accumulated_submission_time'] < max_allowed_runtime_sec) # Check if submission is eligible for an untimed eval. 
- if ((train_step_end_time - train_state['last_eval_time']) >= - workload.eval_period_time_sec or train_state['training_complete']): + if ((train_step_end_time - train_state['last_eval_time']) + >= workload.eval_period_time_sec or train_state['training_complete']): with profiler.profile('Evaluation'): del batch _reset_cuda_mem() From e6037d6ef49f41abb0d74c44545a40a0f2d8c109 Mon Sep 17 00:00:00 2001 From: rka97 Date: Mon, 9 Dec 2024 01:58:03 -0500 Subject: [PATCH 4/8] CIFAR workload sharding --- .../cifar/cifar_jax/input_pipeline.py | 4 +- .../workloads/cifar/cifar_jax/workload.py | 143 +++++++++++------- 2 files changed, 87 insertions(+), 60 deletions(-) diff --git a/algorithmic_efficiency/workloads/cifar/cifar_jax/input_pipeline.py b/algorithmic_efficiency/workloads/cifar/cifar_jax/input_pipeline.py index 3e6a68844..1868dde6e 100644 --- a/algorithmic_efficiency/workloads/cifar/cifar_jax/input_pipeline.py +++ b/algorithmic_efficiency/workloads/cifar/cifar_jax/input_pipeline.py @@ -8,7 +8,6 @@ import functools from typing import Dict, Iterator, Tuple -from flax import jax_utils import jax import tensorflow as tf import tensorflow_datasets as tfds @@ -171,5 +170,6 @@ def create_input_iter( functools.partial( shard_and_maybe_pad_np, global_batch_size=global_batch_size), ds) - it = jax_utils.prefetch_to_device(it, 2) + # FIXME(rka97): Figure out how to do prefetching+sharding. + # it = jax_utils.prefetch_to_device(it, 2) return it diff --git a/algorithmic_efficiency/workloads/cifar/cifar_jax/workload.py b/algorithmic_efficiency/workloads/cifar/cifar_jax/workload.py index b019d1cee..cfafd1145 100644 --- a/algorithmic_efficiency/workloads/cifar/cifar_jax/workload.py +++ b/algorithmic_efficiency/workloads/cifar/cifar_jax/workload.py @@ -3,7 +3,6 @@ import functools from typing import Any, Dict, Iterator, Optional, Tuple -from flax import jax_utils from flax import linen as nn import jax from jax import lax @@ -12,6 +11,7 @@ import tensorflow_datasets as tfds from algorithmic_efficiency import param_utils +from algorithmic_efficiency import sharding_utils from algorithmic_efficiency import spec from algorithmic_efficiency.workloads.cifar.cifar_jax import models from algorithmic_efficiency.workloads.cifar.cifar_jax.input_pipeline import \ @@ -28,15 +28,16 @@ def _build_cifar_dataset( data_dir: str, batch_size: int, cache: Optional[bool] = None, - repeat_final_dataset: Optional[bool] = None + repeat_final_dataset: Optional[bool] = None, ) -> Iterator[Dict[str, spec.Tensor]]: - ds_builder = tfds.builder('cifar10:3.0.2', data_dir=data_dir) - train = split == 'train' + data_dir = data_dir + "/cifar10" + ds_builder = tfds.builder("cifar10:3.0.2", data_dir=data_dir) + train = split == "train" assert self.num_train_examples + self.num_validation_examples == 50000 - if split in ['train', 'eval_train']: - split = f'train[:{self.num_train_examples}]' - elif split == 'validation': - split = f'train[{self.num_train_examples}:]' + if split in ["train", "eval_train"]: + split = f"train[:{self.num_train_examples}]" + elif split == "validation": + split = f"train[{self.num_train_examples}:]" ds = create_input_iter( split, ds_builder, @@ -48,7 +49,8 @@ def _build_cifar_dataset( self.padding_size, train=train, cache=not train if cache is None else cache, - repeat_final_dataset=repeat_final_dataset) + repeat_final_dataset=repeat_final_dataset, + ) return ds def _build_input_queue( @@ -59,7 +61,8 @@ def _build_input_queue( global_batch_size: int, cache: Optional[bool] = None, repeat_final_dataset: Optional[bool] = 
None, - num_batches: Optional[int] = None) -> Iterator[Dict[str, spec.Tensor]]: + num_batches: Optional[int] = None, + ) -> Iterator[Dict[str, spec.Tensor]]: del num_batches return self._build_cifar_dataset(data_rng, split, @@ -74,34 +77,35 @@ def sync_batch_stats( # An axis_name is passed to pmap which can then be used by pmean. # In this case each device has its own version of the batch statistics # and we average them. - avg_fn = jax.pmap(lambda x: lax.pmean(x, 'x'), 'x') + avg_fn = jax.pmap(lambda x: lax.pmean(x, "x"), "x") new_model_state = model_state.copy( - {'batch_stats': avg_fn(model_state['batch_stats'])}) + {"batch_stats": avg_fn(model_state["batch_stats"])}) return new_model_state def init_model_fn( self, rng: spec.RandomState, dropout_rate: Optional[float] = None, - aux_dropout_rate: Optional[float] = None) -> spec.ModelInitState: + aux_dropout_rate: Optional[float] = None, + ) -> spec.ModelInitState: """Dropout is unused.""" del dropout_rate del aux_dropout_rate - model_cls = getattr(models, 'ResNet18') + model_cls = getattr(models, "ResNet18") model = model_cls(num_classes=self._num_classes, dtype=jnp.float32) self._model = model input_shape = (1, 32, 32, 3) - variables = jax.jit(model.init)({'params': rng}, + variables = jax.jit(model.init)({"params": rng}, jnp.ones(input_shape, model.dtype)) - model_state, params = variables.pop('params') + model_state, params = variables.pop("params") self._param_shapes = param_utils.jax_param_shapes(params) self._param_types = param_utils.jax_param_types(self._param_shapes) - model_state = jax_utils.replicate(model_state) - params = jax_utils.replicate(params) + # model_state = jax_utils.replicate(model_state) + # params = jax_utils.replicate(params) return params, model_state def is_output_params(self, param_key: spec.ParameterKey) -> bool: - return param_key == 'Dense_0' + return param_key == "Dense_0" def model_fn( self, @@ -110,23 +114,26 @@ def model_fn( model_state: spec.ModelAuxiliaryState, mode: spec.ForwardPassMode, rng: spec.RandomState, - update_batch_norm: bool) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]: + update_batch_norm: bool, + ) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]: del mode del rng - variables = {'params': params, **model_state} + variables = {"params": params, **model_state} if update_batch_norm: logits, new_model_state = self._model.apply( variables, - augmented_and_preprocessed_input_batch['inputs'], + augmented_and_preprocessed_input_batch["inputs"], update_batch_norm=update_batch_norm, - mutable=['batch_stats']) + mutable=["batch_stats"], + ) return logits, new_model_state else: logits = self._model.apply( variables, - augmented_and_preprocessed_input_batch['inputs'], + augmented_and_preprocessed_input_batch["inputs"], update_batch_norm=update_batch_norm, - mutable=False) + mutable=False, + ) return logits, model_state # Does NOT apply regularization, which is left to the submitter to do in @@ -136,13 +143,15 @@ def loss_fn( label_batch: spec.Tensor, # Dense or one-hot labels. logits_batch: spec.Tensor, mask_batch: Optional[spec.Tensor] = None, - label_smoothing: float = 0.0) -> Dict[str, spec.Tensor]: # differentiable + label_smoothing: float = 0.0, + ) -> Dict[str, spec.Tensor]: # differentiable """Evaluate the (masked) loss function at (label_batch, logits_batch). - Return {'summed': scalar summed loss, 'n_valid_examples': scalar number of - valid examples in batch, 'per_example': 1-d array of per-example losses} - (not synced across devices). 
- """ + Return {'summed': scalar summed loss, + 'n_valid_examples': scalar number of + valid examples in batch, 'per_example': 1-d array of per-example losses} + (not synced across devices). + """ one_hot_targets = jax.nn.one_hot(label_batch, self._num_classes) smoothed_targets = optax.smooth_labels(one_hot_targets, label_smoothing) per_example_losses = -jnp.sum( @@ -155,51 +164,69 @@ def loss_fn( n_valid_examples = len(per_example_losses) summed_loss = per_example_losses.sum() return { - 'summed': summed_loss, - 'n_valid_examples': n_valid_examples, - 'per_example': per_example_losses, + "summed": summed_loss, + "n_valid_examples": n_valid_examples, + "per_example": per_example_losses, } def _compute_metrics(self, logits: spec.Tensor, labels: spec.Tensor, weights: spec.Tensor) -> Dict[str, spec.Tensor]: - summed_loss = self.loss_fn(labels, logits, weights)['summed'] + summed_loss = self.loss_fn(labels, logits, weights)["summed"] # Number of correct predictions. accuracy = jnp.sum((jnp.argmax(logits, -1) == labels) * weights) - metrics = { - 'loss': summed_loss, - 'accuracy': accuracy, - } - metrics = lax.psum(metrics, axis_name='batch') - return metrics + return jnp.array(summed_loss), jnp.array(accuracy) - @functools.partial( - jax.pmap, - axis_name='batch', - in_axes=(None, 0, 0, 0, None), - static_broadcasted_argnums=(0,)) def _eval_model( self, params: spec.ParameterContainer, batch: Dict[str, spec.Tensor], model_state: spec.ModelAuxiliaryState, - rng: spec.RandomState) -> Dict[spec.Tensor, spec.ModelAuxiliaryState]: + rng: spec.RandomState, + ) -> Dict[spec.Tensor, spec.ModelAuxiliaryState]: """Return the mean accuracy and loss as a dict.""" - logits, _ = self.model_fn( - params, - batch, - model_state, - spec.ForwardPassMode.EVAL, - rng, - update_batch_norm=False) - weights = batch.get('weights') - if weights is None: - weights = jnp.ones(len(logits)) - return self._compute_metrics(logits, batch['targets'], weights) + + @functools.partial( + jax.jit, + in_shardings=( + sharding_utils.get_replicated_sharding(), # params + sharding_utils.get_naive_sharding_spec(), # batch + sharding_utils.get_replicated_sharding(), # model_state + sharding_utils.get_naive_sharding_spec(), # rng + ), + ) + def _per_device_eval_model( + params: spec.ParameterContainer, + batch: Dict[str, spec.Tensor], + model_state: spec.ModelAuxiliaryState, + rng: spec.RandomState, + ) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]: + logits, _ = self.model_fn( + params, + batch, + model_state, + spec.ForwardPassMode.EVAL, + rng, + update_batch_norm=False, + ) + weights = batch.get("weights") + if weights is None: + weights = jnp.ones(len(logits)) + return self._compute_metrics(logits, batch["targets"], weights) + + losses, accuracies = _per_device_eval_model(params, batch, model_state, rng) + metrics = { + "loss": + jnp.mean(losses, axis=0) if losses.ndim > 0 else losses, + "accuracy": + (jnp.mean(accuracies, axis=0) if accuracies.ndim > 0 else accuracies + ), + } + return metrics def _normalize_eval_metrics( self, num_examples: int, total_metrics: Dict[str, Any]) -> Dict[str, float]: """Normalize eval metrics.""" - return jax.tree_map(lambda x: float(x[0] / num_examples), total_metrics) + return jax.tree_map(lambda x: x / num_examples, total_metrics) From 0698e3418acd090f9207243acee969dfdd80056d Mon Sep 17 00:00:00 2001 From: rka97 Date: Tue, 7 Jan 2025 21:18:44 +0000 Subject: [PATCH 5/8] librispeech_conformer now running Still need to test out (a) output losses, (b) speed, and (c) look into other librispeech. 
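For reference, a minimal self-contained sketch of the jit + sharding pattern the
eval path now follows (illustrative only: the toy eval_step, params, and batch
below are stand-ins, not the real workload API; assumes the global batch size
divides the device count):

  import jax
  import jax.numpy as jnp
  from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

  # One-dimensional device mesh, as in sharding_utils.get_mesh().
  mesh = Mesh(jax.devices(), ("batch",))
  replicated = NamedSharding(mesh, P())      # params / model_state
  sharded = NamedSharding(mesh, P("batch"))  # shard along the batch axis

  def eval_step(params, batch):
    # Written with global semantics; XLA partitions the work over the mesh
    # and inserts the cross-device reduction for the sum.
    logits = batch["inputs"] @ params["w"]
    return jnp.sum((logits - batch["targets"])**2)

  jitted_eval_step = jax.jit(
      eval_step,
      in_shardings=(replicated, sharded),
      out_shardings=replicated)

  params = {"w": jnp.ones((8, 4))}
  batch = {"inputs": jnp.ones((16, 8)), "targets": jnp.zeros((16, 4))}
  loss = jitted_eval_step(params, batch)  # global scalar, no unreplicate()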
--- .../librispeech_jax/workload.py | 101 +++++++++++++----- .../nesterov/jax/submission.py | 5 +- 2 files changed, 77 insertions(+), 29 deletions(-) diff --git a/algorithmic_efficiency/workloads/librispeech_conformer/librispeech_jax/workload.py b/algorithmic_efficiency/workloads/librispeech_conformer/librispeech_jax/workload.py index f4d1ab0f3..4bcb711f5 100644 --- a/algorithmic_efficiency/workloads/librispeech_conformer/librispeech_jax/workload.py +++ b/algorithmic_efficiency/workloads/librispeech_conformer/librispeech_jax/workload.py @@ -6,6 +6,9 @@ import flax.linen as nn import jax from jax import lax +from jax.sharding import NamedSharding, PartitionSpec as P + +from algorithmic_efficiency import sharding_utils import jax.numpy as jnp import numpy as np import optax @@ -21,7 +24,6 @@ from algorithmic_efficiency.workloads.librispeech_conformer.librispeech_jax import \ models - class LibriSpeechConformerWorkload(workload.BaseLibrispeechWorkload): def __init__(self, @@ -93,8 +95,16 @@ def init_model_fn( self._param_shapes = param_utils.jax_param_shapes(params) self._param_types = param_utils.jax_param_types(self._param_shapes) - model_state = jax_utils.replicate(model_state) - params = jax_utils.replicate(params) + + # Add sharding + mesh = sharding_utils.get_mesh() + params = jax.tree_map( + lambda x: jax.device_put(x, sharding_utils.get_replicated_sharding(mesh)), + params) + model_state = jax.tree_map( + lambda x: jax.device_put(x, sharding_utils.get_replicated_sharding(mesh)), + model_state) + return params, model_state def is_output_params(self, param_key: spec.ParameterKey) -> bool: @@ -176,6 +186,7 @@ def _build_input_queue( 'targets': (targets.numpy(), target_paddings.numpy()), } + # Use data_utils.shard_and_maybe_pad_np to handle sharding padded_batch = data_utils.shard_and_maybe_pad_np( numpy_batch, padding_value=1.0) yield padded_batch @@ -300,11 +311,16 @@ def greedy_decode( return hyp, hyp_paddings @functools.partial( - jax.pmap, - axis_name='batch', - in_axes=(None, 0, 0, 0, None), - static_broadcasted_argnums=(0,)) - def eval_step_pmapped( + jax.jit, + in_shardings=( + sharding_utils.get_replicated_sharding(), # params + sharding_utils.get_naive_sharding_spec(), # batch + sharding_utils.get_replicated_sharding(), # model_state + sharding_utils.get_replicated_sharding(), # rng + ), + out_shardings=sharding_utils.get_naive_sharding_spec(), + static_argnums=(0,)) + def _eval_step( self, params: spec.ParameterContainer, batch: Dict[str, spec.Tensor], @@ -322,13 +338,39 @@ def eval_step_pmapped( loss = self.loss_fn(batch['targets'], (logits, logit_paddings)) targets, target_paddings = batch['targets'] - return self.metrics_bundle.gather_from_model_output( - loss_dict=loss, - decoded=decoded, - decoded_paddings=decoded_paddings, - targets=targets, - target_paddings=target_paddings, - axis_name='batch') + # Convert metrics bundle to dictionary + metrics_dict = { + 'loss_per_example': loss['per_example'], + 'decoded': decoded, + 'decoded_paddings': decoded_paddings, + 'targets': targets, + 'target_paddings': target_paddings, + 'n_valid_examples': jnp.zeros((len(jax.devices()), 1)) + loss['n_valid_examples'] + } + return metrics_dict + + def eval_step( + self, + params: spec.ParameterContainer, + batch: Dict[str, spec.Tensor], + model_state: spec.ModelAuxiliaryState, + rng: spec.RandomState): + """Evaluates the model and returns a metrics bundle.""" + metrics_dict = self._eval_step(params, batch, model_state, rng) + + # Convert dictionary back to metrics bundle + metrics = 
self.metrics_bundle.single_from_model_output( + loss_dict={ + 'summed': metrics_dict['loss_per_example'].sum(), + 'per_example': metrics_dict['loss_per_example'], + 'n_valid_examples': metrics_dict['n_valid_examples'].sum() + }, + decoded=metrics_dict['decoded'], + decoded_paddings=metrics_dict['decoded_paddings'], + targets=metrics_dict['targets'], + target_paddings=metrics_dict['target_paddings']) + + return metrics def _eval_model_on_split(self, split: str, @@ -353,10 +395,10 @@ def _eval_model_on_split(self, metrics_report = None for _ in range(num_batches): eval_batch = next(self._eval_iters[split]) - computed_metrics = self.eval_step_pmapped(params, - eval_batch, - model_state, - rng).unreplicate() + computed_metrics = self.eval_step(params, + eval_batch, + model_state, + rng) if metrics_report is None: metrics_report = computed_metrics @@ -368,15 +410,22 @@ def _eval_model_on_split(self, return computed_metrics + @functools.partial( + jax.jit, + in_shardings=( + sharding_utils.get_replicated_sharding(), # model_state + ), + out_shardings=sharding_utils.get_replicated_sharding(), + static_argnums=(0,) + ) def sync_batch_stats( self, model_state: spec.ModelAuxiliaryState) -> spec.ModelAuxiliaryState: - # An axis_name is passed to pmap which can then be used by pmean. - # In this case each device has its own version of the batch statistics and - # we average them. - avg_fn = jax.pmap(lambda x: lax.pmean(x, 'x'), 'x') - new_model_state = model_state.copy( - {'batch_stats': avg_fn(model_state['batch_stats'])}) - return new_model_state + """Sync batch statistics across replicas.""" + # Replace pmean with direct mean across devices + new_batch_stats = jax.tree_map( + lambda x: jnp.mean(x, axis=0), + model_state['batch_stats']) + return model_state.copy({'batch_stats': new_batch_stats}) class LibriSpeechConformerAttentionTemperatureWorkload( diff --git a/reference_algorithms/paper_baselines/nesterov/jax/submission.py b/reference_algorithms/paper_baselines/nesterov/jax/submission.py index a24e3baab..6a903fd7d 100644 --- a/reference_algorithms/paper_baselines/nesterov/jax/submission.py +++ b/reference_algorithms/paper_baselines/nesterov/jax/submission.py @@ -159,7 +159,6 @@ def update_params(workload: spec.Workload, del eval_results optimizer_state, opt_update_fn = optimizer_state - per_device_rngs = jax.random.split(rng, jax.local_device_count()) if hasattr(hyperparameters, 'label_smoothing'): label_smoothing = hyperparameters.label_smoothing else: @@ -182,7 +181,7 @@ def update_params(workload: spec.Workload, replicated, # optimizer_state replicated, # current_param_container sharded, # batch - sharded, # per_device_rngs + replicated, # rngs replicated, # grad_clip replicated # label_smoothing ) @@ -206,7 +205,7 @@ def update_params(workload: spec.Workload, optimizer_state, current_param_container, batch, - per_device_rngs, + rng, grad_clip, label_smoothing) new_optimizer_state, new_params, new_model_state, loss, grad_norm = outputs From affbf2a9205bdb3a878c93fa5cbe1a40d01ce793 Mon Sep 17 00:00:00 2001 From: Your Name Date: Thu, 16 Jan 2025 16:59:10 +0000 Subject: [PATCH 6/8] fix formatting --- algorithmic_efficiency/sharding_utils.py | 4 +- .../librispeech_jax/models.py | 10 ++- .../librispeech_jax/spectrum_augmenter.py | 4 +- .../librispeech_jax/workload.py | 69 ++++++++++--------- .../librispeech_jax/models.py | 10 ++- .../workloads/mnist/mnist_jax/workload.py | 3 +- .../paper_baselines/adamw/jax/submission.py | 20 +++--- .../nesterov/jax/submission.py | 9 +-- submission_runner.py | 3 + 
9 files changed, 69 insertions(+), 63 deletions(-) diff --git a/algorithmic_efficiency/sharding_utils.py b/algorithmic_efficiency/sharding_utils.py index 62a441bc9..93a4dd53f 100644 --- a/algorithmic_efficiency/sharding_utils.py +++ b/algorithmic_efficiency/sharding_utils.py @@ -1,7 +1,9 @@ """Utilities for dealing with sharding in JAX.""" import jax -from jax.sharding import Mesh, NamedSharding, PartitionSpec +from jax.sharding import Mesh +from jax.sharding import NamedSharding +from jax.sharding import PartitionSpec def get_mesh() -> jax.sharding.Mesh: diff --git a/algorithmic_efficiency/workloads/librispeech_conformer/librispeech_jax/models.py b/algorithmic_efficiency/workloads/librispeech_conformer/librispeech_jax/models.py index ed05f4335..db8cbc70a 100644 --- a/algorithmic_efficiency/workloads/librispeech_conformer/librispeech_jax/models.py +++ b/algorithmic_efficiency/workloads/librispeech_conformer/librispeech_jax/models.py @@ -153,8 +153,8 @@ def setup(self): self.kernel = self.param('kernel', nn.initializers.xavier_uniform(), self.filter_shape) - self.bias = self.param( - 'bias', lambda rng, s: jnp.zeros(s, jnp.float32), self.output_channels) + self.bias = self.param('bias', lambda rng, s: jnp.zeros(s, jnp.float32), + self.output_channels) @nn.compact def __call__(self, inputs, paddings): @@ -442,12 +442,10 @@ def setup(self): dtype = self.config.dtype self.ra_mean = self.variable('batch_stats', - 'mean', - lambda s: jnp.zeros(s, dtype), + 'mean', lambda s: jnp.zeros(s, dtype), dim) self.ra_var = self.variable('batch_stats', - 'var', - lambda s: jnp.ones(s, dtype), + 'var', lambda s: jnp.ones(s, dtype), dim) self.gamma = self.param('scale', nn.initializers.zeros, dim, dtype) diff --git a/algorithmic_efficiency/workloads/librispeech_conformer/librispeech_jax/spectrum_augmenter.py b/algorithmic_efficiency/workloads/librispeech_conformer/librispeech_jax/spectrum_augmenter.py index 2a6f73d4d..c16740629 100644 --- a/algorithmic_efficiency/workloads/librispeech_conformer/librispeech_jax/spectrum_augmenter.py +++ b/algorithmic_efficiency/workloads/librispeech_conformer/librispeech_jax/spectrum_augmenter.py @@ -81,8 +81,8 @@ def _get_mask(self, jnp.expand_dims(jnp.arange(multiplicity, dtype=jnp.int32), 0), [batch_size, 1]) multiplicity_tensor = masks_per_frame * choose_range - multiplicity_weights = (multiplicity_weights < - multiplicity_tensor).astype(jnp.int32) + multiplicity_weights = (multiplicity_weights + < multiplicity_tensor).astype(jnp.int32) pre_mask = jnp.einsum('bmt,bm->bt', pre_mask, multiplicity_weights) else: pre_mask = jnp.einsum('bmt->bt', pre_mask) diff --git a/algorithmic_efficiency/workloads/librispeech_conformer/librispeech_jax/workload.py b/algorithmic_efficiency/workloads/librispeech_conformer/librispeech_jax/workload.py index 4bcb711f5..27ba9b6c3 100644 --- a/algorithmic_efficiency/workloads/librispeech_conformer/librispeech_jax/workload.py +++ b/algorithmic_efficiency/workloads/librispeech_conformer/librispeech_jax/workload.py @@ -6,16 +6,16 @@ import flax.linen as nn import jax from jax import lax -from jax.sharding import NamedSharding, PartitionSpec as P - -from algorithmic_efficiency import sharding_utils import jax.numpy as jnp +from jax.sharding import NamedSharding +from jax.sharding import PartitionSpec as P import numpy as np import optax import torch from algorithmic_efficiency import data_utils from algorithmic_efficiency import param_utils +from algorithmic_efficiency import sharding_utils from algorithmic_efficiency import spec from 
algorithmic_efficiency.workloads.librispeech_conformer import metrics from algorithmic_efficiency.workloads.librispeech_conformer import workload @@ -24,6 +24,7 @@ from algorithmic_efficiency.workloads.librispeech_conformer.librispeech_jax import \ models + class LibriSpeechConformerWorkload(workload.BaseLibrispeechWorkload): def __init__(self, @@ -95,16 +96,18 @@ def init_model_fn( self._param_shapes = param_utils.jax_param_shapes(params) self._param_types = param_utils.jax_param_types(self._param_shapes) - + # Add sharding mesh = sharding_utils.get_mesh() params = jax.tree_map( - lambda x: jax.device_put(x, sharding_utils.get_replicated_sharding(mesh)), + lambda x: jax.device_put(x, sharding_utils.get_replicated_sharding(mesh) + ), params) model_state = jax.tree_map( - lambda x: jax.device_put(x, sharding_utils.get_replicated_sharding(mesh)), + lambda x: jax.device_put(x, sharding_utils.get_replicated_sharding(mesh) + ), model_state) - + return params, model_state def is_output_params(self, param_key: spec.ParameterKey) -> bool: @@ -340,36 +343,41 @@ def _eval_step( targets, target_paddings = batch['targets'] # Convert metrics bundle to dictionary metrics_dict = { - 'loss_per_example': loss['per_example'], - 'decoded': decoded, - 'decoded_paddings': decoded_paddings, - 'targets': targets, - 'target_paddings': target_paddings, - 'n_valid_examples': jnp.zeros((len(jax.devices()), 1)) + loss['n_valid_examples'] + 'loss_per_example': + loss['per_example'], + 'decoded': + decoded, + 'decoded_paddings': + decoded_paddings, + 'targets': + targets, + 'target_paddings': + target_paddings, + 'n_valid_examples': + jnp.zeros((len(jax.devices()), 1)) + loss['n_valid_examples'] } return metrics_dict - def eval_step( - self, - params: spec.ParameterContainer, - batch: Dict[str, spec.Tensor], - model_state: spec.ModelAuxiliaryState, - rng: spec.RandomState): + def eval_step(self, + params: spec.ParameterContainer, + batch: Dict[str, spec.Tensor], + model_state: spec.ModelAuxiliaryState, + rng: spec.RandomState): """Evaluates the model and returns a metrics bundle.""" metrics_dict = self._eval_step(params, batch, model_state, rng) - + # Convert dictionary back to metrics bundle metrics = self.metrics_bundle.single_from_model_output( loss_dict={ - 'summed': metrics_dict['loss_per_example'].sum(), - 'per_example': metrics_dict['loss_per_example'], - 'n_valid_examples': metrics_dict['n_valid_examples'].sum() + 'summed': metrics_dict['loss_per_example'].sum(), + 'per_example': metrics_dict['loss_per_example'], + 'n_valid_examples': metrics_dict['n_valid_examples'].sum() }, decoded=metrics_dict['decoded'], decoded_paddings=metrics_dict['decoded_paddings'], targets=metrics_dict['targets'], target_paddings=metrics_dict['target_paddings']) - + return metrics def _eval_model_on_split(self, @@ -395,10 +403,7 @@ def _eval_model_on_split(self, metrics_report = None for _ in range(num_batches): eval_batch = next(self._eval_iters[split]) - computed_metrics = self.eval_step(params, - eval_batch, - model_state, - rng) + computed_metrics = self.eval_step(params, eval_batch, model_state, rng) if metrics_report is None: metrics_report = computed_metrics @@ -416,15 +421,13 @@ def _eval_model_on_split(self, sharding_utils.get_replicated_sharding(), # model_state ), out_shardings=sharding_utils.get_replicated_sharding(), - static_argnums=(0,) - ) + static_argnums=(0,)) def sync_batch_stats( self, model_state: spec.ModelAuxiliaryState) -> spec.ModelAuxiliaryState: """Sync batch statistics across replicas.""" # Replace pmean 
with direct mean across devices - new_batch_stats = jax.tree_map( - lambda x: jnp.mean(x, axis=0), - model_state['batch_stats']) + new_batch_stats = jax.tree_map(lambda x: jnp.mean(x, axis=0), + model_state['batch_stats']) return model_state.copy({'batch_stats': new_batch_stats}) diff --git a/algorithmic_efficiency/workloads/librispeech_deepspeech/librispeech_jax/models.py b/algorithmic_efficiency/workloads/librispeech_deepspeech/librispeech_jax/models.py index f9eb732e9..c2fe540a6 100644 --- a/algorithmic_efficiency/workloads/librispeech_deepspeech/librispeech_jax/models.py +++ b/algorithmic_efficiency/workloads/librispeech_deepspeech/librispeech_jax/models.py @@ -139,8 +139,8 @@ def setup(self): self.kernel = self.param('kernel', nn.initializers.xavier_uniform(), self.filter_shape) - self.bias = self.param( - 'bias', lambda rng, s: jnp.zeros(s, jnp.float32), self.output_channels) + self.bias = self.param('bias', lambda rng, s: jnp.zeros(s, jnp.float32), + self.output_channels) @nn.compact def __call__(self, inputs, paddings, train): @@ -273,12 +273,10 @@ def setup(self): dtype = self.dtype self.ra_mean = self.variable('batch_stats', - 'mean', - lambda s: jnp.zeros(s, dtype), + 'mean', lambda s: jnp.zeros(s, dtype), dim) self.ra_var = self.variable('batch_stats', - 'var', - lambda s: jnp.ones(s, dtype), + 'var', lambda s: jnp.ones(s, dtype), dim) self.gamma = self.param('scale', nn.initializers.zeros, dim, dtype) diff --git a/algorithmic_efficiency/workloads/mnist/mnist_jax/workload.py b/algorithmic_efficiency/workloads/mnist/mnist_jax/workload.py index b0a52d77f..b57dd72dd 100644 --- a/algorithmic_efficiency/workloads/mnist/mnist_jax/workload.py +++ b/algorithmic_efficiency/workloads/mnist/mnist_jax/workload.py @@ -10,7 +10,8 @@ import jax.numpy as jnp import optax -from algorithmic_efficiency import param_utils, sharding_utils +from algorithmic_efficiency import param_utils +from algorithmic_efficiency import sharding_utils from algorithmic_efficiency import spec from algorithmic_efficiency.workloads.mnist.workload import BaseMnistWorkload diff --git a/reference_algorithms/paper_baselines/adamw/jax/submission.py b/reference_algorithms/paper_baselines/adamw/jax/submission.py index 40be6dc58..cf68b6143 100644 --- a/reference_algorithms/paper_baselines/adamw/jax/submission.py +++ b/reference_algorithms/paper_baselines/adamw/jax/submission.py @@ -6,11 +6,13 @@ from flax import jax_utils import jax from jax import lax -from jax.sharding import NamedSharding, PartitionSpec as P import jax.numpy as jnp +from jax.sharding import NamedSharding +from jax.sharding import PartitionSpec as P import optax -from algorithmic_efficiency import spec, sharding_utils +from algorithmic_efficiency import sharding_utils +from algorithmic_efficiency import spec _GRAD_CLIP_EPS = 1e-6 @@ -121,7 +123,6 @@ def update_params(workload: spec.Workload, del eval_results optimizer_state, opt_update_fn = optimizer_state - per_device_rngs = jax.random.split(rng, jax.local_device_count()) if hasattr(hyperparameters, 'label_smoothing'): label_smoothing = hyperparameters.label_smoothing else: @@ -143,17 +144,17 @@ def update_params(workload: spec.Workload, replicated, # model_state replicated, # optimizer_state replicated, # current_param_container - sharded, # batch - sharded, # rng + sharded, # batch + replicated, # rng replicated, # grad_clip - replicated # label_smoothing + replicated # label_smoothing ) out_shardings = ( replicated, # new_optimizer_state replicated, # updated_params replicated, # new_model_state 
replicated, # loss - replicated # grad_norm + replicated # grad_norm ) # Jit with shardings @@ -164,16 +165,15 @@ def update_params(workload: spec.Workload, in_shardings=arg_shardings, out_shardings=out_shardings) - outputs = jitted_train_step(workload, + new_optimizer_state, new_params, new_model_state, loss, grad_norm = jitted_train_step(workload, opt_update_fn, model_state, optimizer_state, current_param_container, batch, - per_device_rngs, + rng, grad_clip, label_smoothing) - new_optimizer_state, new_params, new_model_state, loss, grad_norm = outputs # Log loss, grad_norm. if global_step % 100 == 0 and workload.metrics_logger is not None: diff --git a/reference_algorithms/paper_baselines/nesterov/jax/submission.py b/reference_algorithms/paper_baselines/nesterov/jax/submission.py index 6a903fd7d..e832c06ac 100644 --- a/reference_algorithms/paper_baselines/nesterov/jax/submission.py +++ b/reference_algorithms/paper_baselines/nesterov/jax/submission.py @@ -6,11 +6,13 @@ from flax import jax_utils import jax from jax import lax -from jax.sharding import NamedSharding, PartitionSpec as P import jax.numpy as jnp +from jax.sharding import NamedSharding +from jax.sharding import PartitionSpec as P import optax -from algorithmic_efficiency import spec, sharding_utils +from algorithmic_efficiency import sharding_utils +from algorithmic_efficiency import spec _GRAD_CLIP_EPS = 1e-6 @@ -199,7 +201,7 @@ def update_params(workload: spec.Workload, donate_argnums=(2, 3, 4), in_shardings=arg_shardings, out_shardings=out_shardings) - outputs = jitted_train_step(workload, + new_optimizer_state, new_params, new_model_state, loss, grad_norm = jitted_train_step(workload, opt_update_fn, model_state, optimizer_state, @@ -208,7 +210,6 @@ def update_params(workload: spec.Workload, rng, grad_clip, label_smoothing) - new_optimizer_state, new_params, new_model_state, loss, grad_norm = outputs # Log loss, grad_norm. 
if global_step % 100 == 0 and workload.metrics_logger is not None: diff --git a/submission_runner.py b/submission_runner.py index d40b37bf4..08b3f13b6 100644 --- a/submission_runner.py +++ b/submission_runner.py @@ -28,6 +28,9 @@ from absl import flags from absl import logging import jax + +jax.config.update('jax_default_prng_impl', 'threefry2x32') +jax.config.update('jax_threefry_partitionable', True) import torch import torch.distributed as dist From 52798078252e03234c99deacd0cffbad1f5bf638 Mon Sep 17 00:00:00 2001 From: Your Name Date: Thu, 16 Jan 2025 17:26:06 +0000 Subject: [PATCH 7/8] shard default --- algorithmic_efficiency/data_utils.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/algorithmic_efficiency/data_utils.py b/algorithmic_efficiency/data_utils.py index 557b4a68d..fddbfb299 100644 --- a/algorithmic_efficiency/data_utils.py +++ b/algorithmic_efficiency/data_utils.py @@ -12,6 +12,7 @@ from torch.utils.data import Sampler from algorithmic_efficiency import spec +from algorithmic_efficiency import sharding_utils def shard_and_maybe_pad_np( @@ -60,7 +61,7 @@ def _prepare(x): if remainder_size != 0 or pad_to_global_batch_size: x = pad(x, pad_size, padding_value=padding_value) - return x + return jax.device_put(x, sharding_utils.get_naive_sharding_spec()) # x return jax.tree_map(_prepare, batch) From be9a68a89bce337dcae3a857df74860dc8366358 Mon Sep 17 00:00:00 2001 From: Your Name Date: Thu, 16 Jan 2025 18:30:48 +0000 Subject: [PATCH 8/8] start imagenet --- algorithmic_efficiency/data_utils.py | 4 +- .../imagenet_jax/input_pipeline.py | 2 +- .../imagenet_jax/test_imagenet_model_jax.py | 0 .../imagenet_resnet/imagenet_jax/workload.py | 50 +++++++++++++------ 4 files changed, 37 insertions(+), 19 deletions(-) create mode 100644 algorithmic_efficiency/workloads/imagenet_resnet/imagenet_jax/test_imagenet_model_jax.py diff --git a/algorithmic_efficiency/data_utils.py b/algorithmic_efficiency/data_utils.py index fddbfb299..2a7a3c74c 100644 --- a/algorithmic_efficiency/data_utils.py +++ b/algorithmic_efficiency/data_utils.py @@ -51,7 +51,7 @@ def shard_and_maybe_pad_np( weights = batch.get('weights') # The weights will also be padded. batch['weights'] = np.ones(mask_shape) if weights is None else weights - + naive_sharding_spec = sharding_utils.get_naive_sharding_spec() def _prepare(x): # Use _numpy() for zero-copy conversion between TF and NumPy. if not isinstance(x, np.ndarray): @@ -61,7 +61,7 @@ def _prepare(x): if remainder_size != 0 or pad_to_global_batch_size: x = pad(x, pad_size, padding_value=padding_value) - return jax.device_put(x, sharding_utils.get_naive_sharding_spec()) # x + return jax.device_put(x, naive_sharding_spec) return jax.tree_map(_prepare, batch) diff --git a/algorithmic_efficiency/workloads/imagenet_resnet/imagenet_jax/input_pipeline.py b/algorithmic_efficiency/workloads/imagenet_resnet/imagenet_jax/input_pipeline.py index 422eb9f7a..556f0a4a1 100644 --- a/algorithmic_efficiency/workloads/imagenet_resnet/imagenet_jax/input_pipeline.py +++ b/algorithmic_efficiency/workloads/imagenet_resnet/imagenet_jax/input_pipeline.py @@ -400,6 +400,6 @@ def create_input_iter(split: str, ds) # Note(Dan S): On a Nvidia 2080 Ti GPU, this increased GPU utilization by 10%. 
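+  # NOTE: flax's jax_utils.prefetch_to_device assumes pmap-style batches with a
+  # leading per-device axis; batches are now placed via NamedSharding in
+  # data_utils.shard_and_maybe_pad_np, so prefetching is disabled until a
+  # sharding-aware replacement exists (see the FIXME in the CIFAR pipeline).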
- it = jax_utils.prefetch_to_device(it, 2) + # it = jax_utils.prefetch_to_device(it, 2) return iter(it) diff --git a/algorithmic_efficiency/workloads/imagenet_resnet/imagenet_jax/test_imagenet_model_jax.py b/algorithmic_efficiency/workloads/imagenet_resnet/imagenet_jax/test_imagenet_model_jax.py new file mode 100644 index 000000000..e69de29bb diff --git a/algorithmic_efficiency/workloads/imagenet_resnet/imagenet_jax/workload.py b/algorithmic_efficiency/workloads/imagenet_resnet/imagenet_jax/workload.py index d8de214f5..5d53ccbed 100644 --- a/algorithmic_efficiency/workloads/imagenet_resnet/imagenet_jax/workload.py +++ b/algorithmic_efficiency/workloads/imagenet_resnet/imagenet_jax/workload.py @@ -17,6 +17,7 @@ import optax import tensorflow_datasets as tfds +from algorithmic_efficiency import sharding_utils from algorithmic_efficiency import param_utils from algorithmic_efficiency import random_utils as prng from algorithmic_efficiency import spec @@ -72,16 +73,20 @@ def _build_dataset( use_randaug=use_randaug) return ds + @functools.partial( + jax.jit, + in_shardings=( + sharding_utils.get_replicated_sharding(), # model_state + ), + out_shardings=sharding_utils.get_replicated_sharding(), + static_argnums=(0,)) def sync_batch_stats( self, model_state: spec.ModelAuxiliaryState) -> spec.ModelAuxiliaryState: - """Sync the batch statistics across replicas.""" - # An axis_name is passed to pmap which can then be used by pmean. - # In this case each device has its own version of the batch statistics and - # we average them. - avg_fn = jax.pmap(lambda x: lax.pmean(x, 'x'), 'x') - new_model_state = model_state.copy( - {'batch_stats': avg_fn(model_state['batch_stats'])}) - return new_model_state + """Sync batch statistics across replicas.""" + new_batch_stats = jax.tree_map(lambda x: jnp.mean(x, axis=0), + model_state['batch_stats']) + return model_state.copy({'batch_stats': new_batch_stats}) + def init_model_fn( self, @@ -114,18 +119,30 @@ def init_model_fn( model_state, params = variables.pop('params') self._param_shapes = param_utils.jax_param_shapes(params) self._param_types = param_utils.jax_param_types(self._param_shapes) - model_state = jax_utils.replicate(model_state) - params = jax_utils.replicate(params) + mesh = sharding_utils.get_mesh() + params = jax.tree_map( + lambda x: jax.device_put(x, sharding_utils.get_replicated_sharding(mesh) + ), + params) + model_state = jax.tree_map( + lambda x: jax.device_put(x, sharding_utils.get_replicated_sharding(mesh) + ), + model_state) return params, model_state def is_output_params(self, param_key: spec.ParameterKey) -> bool: return param_key == 'Dense_0' @functools.partial( - jax.pmap, - axis_name='batch', - in_axes=(None, 0, 0, 0, 0), - static_broadcasted_argnums=(0,)) + jax.jit, + in_shardings=( + sharding_utils.get_replicated_sharding(), # params + sharding_utils.get_naive_sharding_spec(), # batch + sharding_utils.get_replicated_sharding(), # model_state + sharding_utils.get_replicated_sharding(), # rng + ), + static_argnums=(0,), + out_shardings=sharding_utils.get_naive_sharding_spec()) def _eval_model(self, params: spec.ParameterContainer, batch: Dict[str, spec.Tensor], @@ -215,7 +232,7 @@ def _compute_metrics(self, 'loss': summed_loss, 'accuracy': accuracy, } - metrics = lax.psum(metrics, axis_name='batch') + # metrics = lax.psum(metrics, axis_name='batch') return metrics def _eval_model_on_split(self, @@ -249,11 +266,12 @@ def _eval_model_on_split(self, eval_rng = prng.fold_in(eval_rng, bi) step_eval_rngs = prng.split(eval_rng, 
jax.local_device_count()) batch = next(self._eval_iters[split]) - # We already average these metrics across devices inside _compute_metrics. synced_metrics = self._eval_model(params, batch, model_state, step_eval_rngs) + # Sum up the synced metrics + synced_metrics = jax.tree_map(lambda x: jnp.sum(x, axis=0), synced_metrics) for metric_name, metric_value in synced_metrics.items(): if metric_name not in eval_metrics: eval_metrics[metric_name] = 0.0