diff --git a/algorithmic_efficiency/checkpoint_utils.py b/algorithmic_efficiency/checkpoint_utils.py
index 29c1a821e..c98799ca7 100644
--- a/algorithmic_efficiency/checkpoint_utils.py
+++ b/algorithmic_efficiency/checkpoint_utils.py
@@ -193,10 +193,7 @@ def save_checkpoint(framework: str,
     train_state, eval_results, global_step, preemption_count).
   """
   if framework == 'jax':
-    model_params = jax.device_get(jax_utils.unreplicate(model_params))
     opt_state, _ = optimizer_state
-    opt_state = jax.device_get(jax_utils.unreplicate(opt_state))
-    model_state = jax.device_get(jax_utils.unreplicate(model_state))
   else:
     if isinstance(
         model_params,
diff --git a/algorithmic_efficiency/data_utils.py b/algorithmic_efficiency/data_utils.py
index 901f0b582..557b4a68d 100644
--- a/algorithmic_efficiency/data_utils.py
+++ b/algorithmic_efficiency/data_utils.py
@@ -60,10 +60,7 @@ def _prepare(x):
     if remainder_size != 0 or pad_to_global_batch_size:
       x = pad(x, pad_size, padding_value=padding_value)

-    # Reshape (global_batch_size, ...) to
-    # (local_device_count, per_device_batch_size, ...).
-    # Assumes that `global_batch_size % local_device_count == 0`.
-    return x.reshape((local_device_count, -1, *x.shape[1:]))
+    return x

   return jax.tree_map(_prepare, batch)
diff --git a/algorithmic_efficiency/sharding_utils.py b/algorithmic_efficiency/sharding_utils.py
new file mode 100644
index 000000000..62a441bc9
--- /dev/null
+++ b/algorithmic_efficiency/sharding_utils.py
@@ -0,0 +1,61 @@
+"""Utilities for dealing with sharding in JAX."""
+
+import jax
+from jax.sharding import Mesh, NamedSharding, PartitionSpec
+
+
+def get_mesh() -> jax.sharding.Mesh:
+  """Creates a one-dimensional ("batch") mesh over all available devices."""
+  return jax.sharding.Mesh(jax.devices(), ("batch",))
+
+
+def get_replicated_sharding(mesh=None):
+  """Returns a sharding spec that replicates data across all devices."""
+  if mesh is None:
+    mesh = get_mesh()
+  return NamedSharding(mesh, PartitionSpec())
+
+
+def get_naive_sharding_spec(mesh=None):
+  """Returns a sharding spec that shards data along the first axis."""
+  if mesh is None:
+    mesh = get_mesh()
+  return NamedSharding(mesh, PartitionSpec("batch"))
+
+
+def get_naive_sharding(x, mesh=None):
+  """Given a 1D mesh and a tensor, try to shard along the appropriate axis."""
+  if mesh is None:
+    mesh = get_mesh()
+  grid_size = mesh.shape["batch"]
+  if x.shape[0] % grid_size == 0:
+    return NamedSharding(mesh, PartitionSpec("batch"))
+  else:
+    return NamedSharding(mesh, PartitionSpec())
+
+
+def shard_params(params, mesh=None):
+  """Shards a parameter tree across all devices with naive sharding (see get_naive_sharding)."""
+  if mesh is None:
+    mesh = get_mesh()
+  return jax.tree_util.tree_map(
+      lambda x: jax.device_put(x, get_naive_sharding(x, mesh)), params)
+
+
+def get_sharding_tree(params, mesh=None):
+  """Returns a sharding tree for a parameter tree."""
+  return jax.tree_util.tree_map(lambda x: get_naive_sharding(x, mesh), params)
+
+
+def get_empty_sharding(mesh=None):
+  """Returns a sharding spec that replicates data across all devices."""
+  if mesh is None:
+    mesh = get_mesh()
+  return NamedSharding(mesh, PartitionSpec())
+
+
+def disp_shard_info(x: jax.Array):
+  """Displays shard info of a jax array."""
+  for shard in x.addressable_shards:
+    print(f"shard.device: {shard.device}, index: {shard.index}, replica_id:"
+          f" {shard.replica_id}.\n")
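
The new module above is the backbone of the pmap-to-jit migration in the rest of this diff: batches keep their global shape (see the data_utils.py change, which drops the per-device reshape) and are placed with a batch-partitioned NamedSharding, while parameters, optimizer state, and model state stay replicated. A minimal usage sketch of these helpers (the 128-example batch and the toy parameter shapes are assumptions for illustration, not part of the diff):

import jax
import jax.numpy as jnp

from algorithmic_efficiency import sharding_utils

mesh = sharding_utils.get_mesh()  # one-dimensional mesh with a single "batch" axis
batch = jnp.ones((128, 32, 32, 3))  # global batch, no leading device axis
params = {"w": jnp.ones((3, 10))}

# The batch is split along axis 0 (when divisible by the device count);
# the parameters are replicated on every device.
batch = jax.device_put(batch, sharding_utils.get_naive_sharding(batch, mesh))
params = jax.device_put(params, sharding_utils.get_replicated_sharding(mesh))
sharding_utils.disp_shard_info(batch)  # prints one line per addressable shard
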
diff --git a/algorithmic_efficiency/workloads/cifar/cifar_jax/input_pipeline.py b/algorithmic_efficiency/workloads/cifar/cifar_jax/input_pipeline.py
index 3e6a68844..1868dde6e 100644
--- a/algorithmic_efficiency/workloads/cifar/cifar_jax/input_pipeline.py
+++ b/algorithmic_efficiency/workloads/cifar/cifar_jax/input_pipeline.py
@@ -8,7 +8,6 @@ import functools
 from typing import Dict, Iterator, Tuple

-from flax import jax_utils
 import jax
 import tensorflow as tf
 import tensorflow_datasets as tfds
@@ -171,5 +170,6 @@ def create_input_iter(
       functools.partial(
           shard_and_maybe_pad_np, global_batch_size=global_batch_size),
       ds)
-  it = jax_utils.prefetch_to_device(it, 2)
+  # FIXME(rka97): Figure out how to do prefetching+sharding.
+  # it = jax_utils.prefetch_to_device(it, 2)
   return it
diff --git a/algorithmic_efficiency/workloads/cifar/cifar_jax/workload.py b/algorithmic_efficiency/workloads/cifar/cifar_jax/workload.py
index b019d1cee..cfafd1145 100644
--- a/algorithmic_efficiency/workloads/cifar/cifar_jax/workload.py
+++ b/algorithmic_efficiency/workloads/cifar/cifar_jax/workload.py
@@ -3,7 +3,6 @@ import functools
 from typing import Any, Dict, Iterator, Optional, Tuple

-from flax import jax_utils
 from flax import linen as nn
 import jax
 from jax import lax
@@ -12,6 +11,7 @@ import tensorflow_datasets as tfds

 from algorithmic_efficiency import param_utils
+from algorithmic_efficiency import sharding_utils
 from algorithmic_efficiency import spec
 from algorithmic_efficiency.workloads.cifar.cifar_jax import models
 from algorithmic_efficiency.workloads.cifar.cifar_jax.input_pipeline import \
@@ -28,15 +28,16 @@ def _build_cifar_dataset(
       data_dir: str,
       batch_size: int,
       cache: Optional[bool] = None,
-      repeat_final_dataset: Optional[bool] = None
+      repeat_final_dataset: Optional[bool] = None,
   ) -> Iterator[Dict[str, spec.Tensor]]:
-    ds_builder = tfds.builder('cifar10:3.0.2', data_dir=data_dir)
-    train = split == 'train'
+    data_dir = data_dir + "/cifar10"
+    ds_builder = tfds.builder("cifar10:3.0.2", data_dir=data_dir)
+    train = split == "train"
     assert self.num_train_examples + self.num_validation_examples == 50000
-    if split in ['train', 'eval_train']:
-      split = f'train[:{self.num_train_examples}]'
-    elif split == 'validation':
-      split = f'train[{self.num_train_examples}:]'
+    if split in ["train", "eval_train"]:
+      split = f"train[:{self.num_train_examples}]"
+    elif split == "validation":
+      split = f"train[{self.num_train_examples}:]"
     ds = create_input_iter(
         split,
         ds_builder,
@@ -48,7 +49,8 @@ def _build_cifar_dataset(
         self.padding_size,
         train=train,
         cache=not train if cache is None else cache,
-        repeat_final_dataset=repeat_final_dataset)
+        repeat_final_dataset=repeat_final_dataset,
+    )
     return ds

   def _build_input_queue(
@@ -59,7 +61,8 @@ def _build_input_queue(
       global_batch_size: int,
       cache: Optional[bool] = None,
       repeat_final_dataset: Optional[bool] = None,
-      num_batches: Optional[int] = None) -> Iterator[Dict[str, spec.Tensor]]:
+      num_batches: Optional[int] = None,
+  ) -> Iterator[Dict[str, spec.Tensor]]:
     del num_batches
     return self._build_cifar_dataset(data_rng,
                                      split,
@@ -74,34 +77,35 @@ def sync_batch_stats(
     # An axis_name is passed to pmap which can then be used by pmean.
     # In this case each device has its own version of the batch statistics
     # and we average them.
-    avg_fn = jax.pmap(lambda x: lax.pmean(x, 'x'), 'x')
+    avg_fn = jax.pmap(lambda x: lax.pmean(x, "x"), "x")
     new_model_state = model_state.copy(
-        {'batch_stats': avg_fn(model_state['batch_stats'])})
+        {"batch_stats": avg_fn(model_state["batch_stats"])})
     return new_model_state

   def init_model_fn(
       self,
       rng: spec.RandomState,
       dropout_rate: Optional[float] = None,
-      aux_dropout_rate: Optional[float] = None) -> spec.ModelInitState:
+      aux_dropout_rate: Optional[float] = None,
+  ) -> spec.ModelInitState:
     """Dropout is unused."""
     del dropout_rate
     del aux_dropout_rate
-    model_cls = getattr(models, 'ResNet18')
+    model_cls = getattr(models, "ResNet18")
     model = model_cls(num_classes=self._num_classes, dtype=jnp.float32)
     self._model = model
     input_shape = (1, 32, 32, 3)
-    variables = jax.jit(model.init)({'params': rng},
+    variables = jax.jit(model.init)({"params": rng},
                                     jnp.ones(input_shape, model.dtype))
-    model_state, params = variables.pop('params')
+    model_state, params = variables.pop("params")
     self._param_shapes = param_utils.jax_param_shapes(params)
     self._param_types = param_utils.jax_param_types(self._param_shapes)
-    model_state = jax_utils.replicate(model_state)
-    params = jax_utils.replicate(params)
+    # model_state = jax_utils.replicate(model_state)
+    # params = jax_utils.replicate(params)
     return params, model_state

   def is_output_params(self, param_key: spec.ParameterKey) -> bool:
-    return param_key == 'Dense_0'
+    return param_key == "Dense_0"

   def model_fn(
       self,
@@ -110,23 +114,26 @@
       model_state: spec.ModelAuxiliaryState,
       mode: spec.ForwardPassMode,
       rng: spec.RandomState,
-      update_batch_norm: bool) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]:
+      update_batch_norm: bool,
+  ) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]:
     del mode
     del rng
-    variables = {'params': params, **model_state}
+    variables = {"params": params, **model_state}
     if update_batch_norm:
       logits, new_model_state = self._model.apply(
           variables,
-          augmented_and_preprocessed_input_batch['inputs'],
+          augmented_and_preprocessed_input_batch["inputs"],
           update_batch_norm=update_batch_norm,
-          mutable=['batch_stats'])
+          mutable=["batch_stats"],
+      )
       return logits, new_model_state
     else:
       logits = self._model.apply(
           variables,
-          augmented_and_preprocessed_input_batch['inputs'],
+          augmented_and_preprocessed_input_batch["inputs"],
           update_batch_norm=update_batch_norm,
-          mutable=False)
+          mutable=False,
+      )
       return logits, model_state

   # Does NOT apply regularization, which is left to the submitter to do in
@@ -136,13 +143,15 @@ def loss_fn(
       label_batch: spec.Tensor,  # Dense or one-hot labels.
       logits_batch: spec.Tensor,
       mask_batch: Optional[spec.Tensor] = None,
-      label_smoothing: float = 0.0) -> Dict[str, spec.Tensor]:  # differentiable
+      label_smoothing: float = 0.0,
+  ) -> Dict[str, spec.Tensor]:  # differentiable
     """Evaluate the (masked) loss function at (label_batch, logits_batch).

-    Return {'summed': scalar summed loss, 'n_valid_examples': scalar number of
-    valid examples in batch, 'per_example': 1-d array of per-example losses}
-    (not synced across devices).
-    """
+    Return {'summed': scalar summed loss,
+    'n_valid_examples': scalar number of
+    valid examples in batch, 'per_example': 1-d array of per-example losses}
+    (not synced across devices).
+ """ one_hot_targets = jax.nn.one_hot(label_batch, self._num_classes) smoothed_targets = optax.smooth_labels(one_hot_targets, label_smoothing) per_example_losses = -jnp.sum( @@ -155,51 +164,69 @@ def loss_fn( n_valid_examples = len(per_example_losses) summed_loss = per_example_losses.sum() return { - 'summed': summed_loss, - 'n_valid_examples': n_valid_examples, - 'per_example': per_example_losses, + "summed": summed_loss, + "n_valid_examples": n_valid_examples, + "per_example": per_example_losses, } def _compute_metrics(self, logits: spec.Tensor, labels: spec.Tensor, weights: spec.Tensor) -> Dict[str, spec.Tensor]: - summed_loss = self.loss_fn(labels, logits, weights)['summed'] + summed_loss = self.loss_fn(labels, logits, weights)["summed"] # Number of correct predictions. accuracy = jnp.sum((jnp.argmax(logits, -1) == labels) * weights) - metrics = { - 'loss': summed_loss, - 'accuracy': accuracy, - } - metrics = lax.psum(metrics, axis_name='batch') - return metrics + return jnp.array(summed_loss), jnp.array(accuracy) - @functools.partial( - jax.pmap, - axis_name='batch', - in_axes=(None, 0, 0, 0, None), - static_broadcasted_argnums=(0,)) def _eval_model( self, params: spec.ParameterContainer, batch: Dict[str, spec.Tensor], model_state: spec.ModelAuxiliaryState, - rng: spec.RandomState) -> Dict[spec.Tensor, spec.ModelAuxiliaryState]: + rng: spec.RandomState, + ) -> Dict[spec.Tensor, spec.ModelAuxiliaryState]: """Return the mean accuracy and loss as a dict.""" - logits, _ = self.model_fn( - params, - batch, - model_state, - spec.ForwardPassMode.EVAL, - rng, - update_batch_norm=False) - weights = batch.get('weights') - if weights is None: - weights = jnp.ones(len(logits)) - return self._compute_metrics(logits, batch['targets'], weights) + + @functools.partial( + jax.jit, + in_shardings=( + sharding_utils.get_replicated_sharding(), # params + sharding_utils.get_naive_sharding_spec(), # batch + sharding_utils.get_replicated_sharding(), # model_state + sharding_utils.get_naive_sharding_spec(), # rng + ), + ) + def _per_device_eval_model( + params: spec.ParameterContainer, + batch: Dict[str, spec.Tensor], + model_state: spec.ModelAuxiliaryState, + rng: spec.RandomState, + ) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]: + logits, _ = self.model_fn( + params, + batch, + model_state, + spec.ForwardPassMode.EVAL, + rng, + update_batch_norm=False, + ) + weights = batch.get("weights") + if weights is None: + weights = jnp.ones(len(logits)) + return self._compute_metrics(logits, batch["targets"], weights) + + losses, accuracies = _per_device_eval_model(params, batch, model_state, rng) + metrics = { + "loss": + jnp.mean(losses, axis=0) if losses.ndim > 0 else losses, + "accuracy": + (jnp.mean(accuracies, axis=0) if accuracies.ndim > 0 else accuracies + ), + } + return metrics def _normalize_eval_metrics( self, num_examples: int, total_metrics: Dict[str, Any]) -> Dict[str, float]: """Normalize eval metrics.""" - return jax.tree_map(lambda x: float(x[0] / num_examples), total_metrics) + return jax.tree_map(lambda x: x / num_examples, total_metrics) diff --git a/algorithmic_efficiency/workloads/mnist/mnist_jax/workload.py b/algorithmic_efficiency/workloads/mnist/mnist_jax/workload.py index efbd73e33..b0a52d77f 100644 --- a/algorithmic_efficiency/workloads/mnist/mnist_jax/workload.py +++ b/algorithmic_efficiency/workloads/mnist/mnist_jax/workload.py @@ -10,7 +10,7 @@ import jax.numpy as jnp import optax -from algorithmic_efficiency import param_utils +from algorithmic_efficiency import param_utils, 
diff --git a/algorithmic_efficiency/workloads/mnist/mnist_jax/workload.py b/algorithmic_efficiency/workloads/mnist/mnist_jax/workload.py
index efbd73e33..b0a52d77f 100644
--- a/algorithmic_efficiency/workloads/mnist/mnist_jax/workload.py
+++ b/algorithmic_efficiency/workloads/mnist/mnist_jax/workload.py
@@ -10,7 +10,7 @@ import jax.numpy as jnp
 import optax

-from algorithmic_efficiency import param_utils
+from algorithmic_efficiency import param_utils, sharding_utils
 from algorithmic_efficiency import spec
 from algorithmic_efficiency.workloads.mnist.workload import BaseMnistWorkload
@@ -46,7 +46,7 @@ def init_model_fn(
         train=True)['params']
     self._param_shapes = param_utils.jax_param_shapes(initial_params)
     self._param_types = param_utils.jax_param_types(self._param_shapes)
-    return jax_utils.replicate(initial_params), None
+    return initial_params, None

   def is_output_params(self, param_key: spec.ParameterKey) -> bool:
     return param_key == 'Dense_1'
@@ -101,10 +101,14 @@ def loss_fn(
     }

   @functools.partial(
-      jax.pmap,
-      axis_name='batch',
-      in_axes=(None, 0, 0, 0, None),
-      static_broadcasted_argnums=(0,))
+      jax.jit,
+      in_shardings=(
+          sharding_utils.get_replicated_sharding(),  # params
+          sharding_utils.get_naive_sharding_spec(),  # batch
+          sharding_utils.get_replicated_sharding(),  # model_state
+          sharding_utils.get_naive_sharding_spec(),  # rng
+      ),
+      static_argnums=(0,))
   def _eval_model(
       self,
       params: spec.ParameterContainer,
@@ -125,11 +129,14 @@ def _eval_model(
         (jnp.argmax(logits, axis=-1) == batch['targets']) * weights)
     summed_loss = self.loss_fn(batch['targets'], logits, weights)['summed']
     metrics = {'accuracy': accuracy, 'loss': summed_loss}
-    metrics = lax.psum(metrics, axis_name='batch')
     return metrics

   def _normalize_eval_metrics(
       self, num_examples: int, total_metrics: Dict[str, Any]) -> Dict[str, float]:
     """Normalize eval metrics."""
-    return jax.tree_map(lambda x: float(x[0] / num_examples), total_metrics)
+    total_metrics = {
+        'accuracy': total_metrics['accuracy'].item() / num_examples,
+        'loss': total_metrics['loss'].item() / num_examples
+    }
+    return total_metrics
diff --git a/algorithmic_efficiency/workloads/mnist/workload.py b/algorithmic_efficiency/workloads/mnist/workload.py
index dcc195170..ad950b869 100644
--- a/algorithmic_efficiency/workloads/mnist/workload.py
+++ b/algorithmic_efficiency/workloads/mnist/workload.py
@@ -46,8 +46,7 @@ def _build_mnist_dataset(
     ds = ds.map(
         lambda x: {
            'inputs': _normalize(x['image'], train_mean, train_stddev),
-            'targets': x['label'],
-        })
+            'targets': x['label'],})

     is_train = split == 'train'
     if cache:
@@ -214,8 +213,6 @@ def _eval_model_on_split(self,
                                        batch,
                                        model_state,
                                        per_device_model_rngs)
-      total_metrics = {
-          k: v + batch_metrics[k] for k, v in total_metrics.items()
-      }
+      total_metrics = {k: v + batch_metrics[k] for k, v in total_metrics.items()}

     return self._normalize_eval_metrics(num_examples, total_metrics)
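
Because init_model_fn and init_optimizer_state no longer call jax_utils.replicate, parameters, optimizer state, and the accumulated eval totals are ordinary global jax.Arrays with no leading device axis. That is why _normalize_eval_metrics switches from float(x[0] / num_examples) to plain division (or .item()), and why checkpoint_utils.save_checkpoint above drops its unreplicate calls. If explicit placement is still wanted at init time, one option (an illustration, not something this diff does) is to pin the tree to a replicated sharding:

import jax
import jax.numpy as jnp

from algorithmic_efficiency import sharding_utils

# Toy MNIST-sized parameters; the real workload gets these from model.init.
params = {"w": jnp.ones((784, 10)), "b": jnp.zeros((10,))}

# Old pattern: params = flax.jax_utils.replicate(params)  # adds a device axis
# New pattern: keep params as global arrays, optionally pinned to a
# replicated sharding so later jitted calls do not need to move them.
params = jax.device_put(params, sharding_utils.get_replicated_sharding())
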
diff --git a/reference_algorithms/paper_baselines/adamw/jax/submission.py b/reference_algorithms/paper_baselines/adamw/jax/submission.py
index 80a963600..40be6dc58 100644
--- a/reference_algorithms/paper_baselines/adamw/jax/submission.py
+++ b/reference_algorithms/paper_baselines/adamw/jax/submission.py
@@ -6,10 +6,11 @@ from flax import jax_utils
 import jax
 from jax import lax
+from jax.sharding import NamedSharding, PartitionSpec as P
 import jax.numpy as jnp
 import optax

-from algorithmic_efficiency import spec
+from algorithmic_efficiency import spec, sharding_utils

 _GRAD_CLIP_EPS = 1e-6
@@ -50,24 +51,18 @@ def jax_cosine_warmup(step_hint: int, hyperparameters):
                                      workload.param_shapes)
   optimizer_state = opt_init_fn(params_zeros_like)

-  return jax_utils.replicate(optimizer_state), opt_update_fn
-
-
-@functools.partial(
-    jax.pmap,
-    axis_name='batch',
-    in_axes=(None, None, 0, 0, 0, 0, 0, None, None),
-    static_broadcasted_argnums=(0, 1),
-    donate_argnums=(2, 3, 4))
-def pmapped_train_step(workload,
-                       opt_update_fn,
-                       model_state,
-                       optimizer_state,
-                       current_param_container,
-                       batch,
-                       rng,
-                       grad_clip,
-                       label_smoothing):
+  return optimizer_state, opt_update_fn
+
+
+def train_step(workload,
+               opt_update_fn,
+               model_state,
+               optimizer_state,
+               current_param_container,
+               batch,
+               rng,
+               grad_clip,
+               label_smoothing):

   def _loss_fn(params):
     """Loss function used for training."""
@@ -90,9 +85,8 @@ def _loss_fn(params):
   grad_fn = jax.value_and_grad(_loss_fn, has_aux=True)
   (summed_loss, (n_valid_examples, new_model_state)), grad = grad_fn(
       current_param_container)
-  # Get correct global mean loss and grad.
-  (summed_loss, n_valid_examples, grad) = lax.psum(
-      (summed_loss, n_valid_examples, grad), axis_name='batch')
+
+  # Compute local loss and gradients
   loss = summed_loss / n_valid_examples
   grad = jax.tree_map(lambda x: x / n_valid_examples, grad)

@@ -105,7 +99,7 @@ def _loss_fn(params):
     grad = jax.tree_map(lambda x: x * grad_scaling_factor, grad)

   updates, new_optimizer_state = opt_update_fn(grad, optimizer_state,
-                                               current_param_container)
+                                                current_param_container)
   updated_params = optax.apply_updates(current_param_container, updates)
   return new_optimizer_state, updated_params, new_model_state, loss, grad_norm

@@ -136,23 +130,57 @@ def update_params(workload: spec.Workload,
     grad_clip = hyperparameters.grad_clip
   else:
     grad_clip = None
-  outputs = pmapped_train_step(workload,
-                               opt_update_fn,
-                               model_state,
-                               optimizer_state,
-                               current_param_container,
-                               batch,
-                               per_device_rngs,
-                               grad_clip,
-                               label_smoothing)
+
+  # Set up mesh and sharding
+  mesh = sharding_utils.get_mesh()
+  replicated = NamedSharding(mesh, P())  # No partitioning
+  sharded = NamedSharding(mesh, P('batch'))  # Partition along batch dimension
+
+  # Define input and output shardings
+  arg_shardings = (
+      # workload is static
+      # opt_update_fn is static
+      replicated,  # model_state
+      replicated,  # optimizer_state
+      replicated,  # current_param_container
+      sharded,  # batch
+      sharded,  # rng
+      replicated,  # grad_clip
+      replicated  # label_smoothing
+  )
+  out_shardings = (
+      replicated,  # new_optimizer_state
+      replicated,  # updated_params
+      replicated,  # new_model_state
+      replicated,  # loss
+      replicated  # grad_norm
+  )
+
+  # Jit with shardings
+  jitted_train_step = jax.jit(
+      train_step,
+      static_argnums=(0, 1),
+      donate_argnums=(2, 3, 4),
+      in_shardings=arg_shardings,
+      out_shardings=out_shardings)
+
+  outputs = jitted_train_step(workload,
+                              opt_update_fn,
+                              model_state,
+                              optimizer_state,
+                              current_param_container,
+                              batch,
+                              per_device_rngs,
+                              grad_clip,
+                              label_smoothing)
   new_optimizer_state, new_params, new_model_state, loss, grad_norm = outputs

   # Log loss, grad_norm.
   if global_step % 100 == 0 and workload.metrics_logger is not None:
     workload.metrics_logger.append_scalar_metrics(
         {
-            'loss': loss[0],
-            'grad_norm': grad_norm[0],
+            'loss': loss.item(),
+            'grad_norm': grad_norm.item(),
         }, global_step)
   return (new_optimizer_state, opt_update_fn), new_params, new_model_state
diff --git a/reference_algorithms/paper_baselines/nesterov/jax/submission.py b/reference_algorithms/paper_baselines/nesterov/jax/submission.py
index f3b0aeed4..a24e3baab 100644
--- a/reference_algorithms/paper_baselines/nesterov/jax/submission.py
+++ b/reference_algorithms/paper_baselines/nesterov/jax/submission.py
@@ -6,10 +6,11 @@ from flax import jax_utils
 import jax
 from jax import lax
+from jax.sharding import NamedSharding, PartitionSpec as P
 import jax.numpy as jnp
 import optax

-from algorithmic_efficiency import spec
+from algorithmic_efficiency import spec, sharding_utils

 _GRAD_CLIP_EPS = 1e-6
@@ -37,7 +38,7 @@ def init_optimizer_state(workload: spec.Workload,
                    nesterov=True)
   optimizer_state = opt_init_fn(params_zeros_like)

-  return jax_utils.replicate(optimizer_state), opt_update_fn
+  return optimizer_state, opt_update_fn


 def create_lr_schedule_fn(
@@ -87,21 +88,21 @@ def sgd(learning_rate, weight_decay, momentum=None, nesterov=False):
           learning_rate=learning_rate, momentum=momentum, nesterov=nesterov))


-@functools.partial(
-    jax.pmap,
-    axis_name='batch',
-    in_axes=(None, None, 0, 0, 0, 0, 0, None, None),
-    static_broadcasted_argnums=(0, 1),
-    donate_argnums=(2, 3, 4))
-def pmapped_train_step(workload,
-                       opt_update_fn,
-                       model_state,
-                       optimizer_state,
-                       current_param_container,
-                       batch,
-                       rng,
-                       grad_clip,
-                       label_smoothing):
+# @functools.partial(
+#     jax.pmap,
+#     axis_name='batch',
+#     in_axes=(None, None, 0, 0, 0, 0, 0, None, None),
+#     static_broadcasted_argnums=(0, 1),
+#     donate_argnums=(2, 3, 4))
+def train_step(workload,
+               opt_update_fn,
+               model_state,
+               optimizer_state,
+               current_param_container,
+               batch,
+               rng,
+               grad_clip,
+               label_smoothing):

   def _loss_fn(params):
     """Loss function used for training."""
@@ -124,12 +125,9 @@ def _loss_fn(params):
   grad_fn = jax.value_and_grad(_loss_fn, has_aux=True)
   (summed_loss, (n_valid_examples, new_model_state)), grad = grad_fn(
       current_param_container)
-  # Get correct global mean loss and grad.
-  (summed_loss, n_valid_examples, grad) = lax.psum(
-      (summed_loss, n_valid_examples, grad), axis_name='batch')
+  # # Get correct global mean loss and grad.
   loss = summed_loss / n_valid_examples
   grad = jax.tree_map(lambda x: x / n_valid_examples, grad)
-
   grad_norm = jnp.sqrt(
       sum(jnp.sum(g**2) for g in jax.tree_util.tree_leaves(grad)))

@@ -170,23 +168,55 @@ def update_params(workload: spec.Workload,
     grad_clip = hyperparameters.grad_clip
   else:
     grad_clip = None
-  outputs = pmapped_train_step(workload,
-                               opt_update_fn,
-                               model_state,
-                               optimizer_state,
-                               current_param_container,
-                               batch,
-                               per_device_rngs,
-                               grad_clip,
-                               label_smoothing)
+
+  mesh = sharding_utils.get_mesh()
+  # Create shardings for each argument
+  replicated = NamedSharding(mesh, P())  # No partitioning
+  sharded = NamedSharding(mesh, P('batch'))  # Partition along batch dimension
+
+  # Create the sharding rules for each argument
+  arg_shardings = (
+      # workload is static
+      # opt_update_fn is static
+      replicated,  # model_state
+      replicated,  # optimizer_state
+      replicated,  # current_param_container
+      sharded,  # batch
+      sharded,  # per_device_rngs
+      replicated,  # grad_clip
+      replicated  # label_smoothing
+  )
+  out_shardings = (
+      replicated,  # new_optimizer_state
+      replicated,  # updated_params
+      replicated,  # new_model_state
+      replicated,  # loss
+      replicated  # grad_norm
+  )
+  # Jit with shardings
+  jitted_train_step = jax.jit(
+      train_step,
+      static_argnums=(0, 1),
+      donate_argnums=(2, 3, 4),
+      in_shardings=arg_shardings,
+      out_shardings=out_shardings)
+  outputs = jitted_train_step(workload,
+                              opt_update_fn,
+                              model_state,
+                              optimizer_state,
+                              current_param_container,
+                              batch,
+                              per_device_rngs,
+                              grad_clip,
+                              label_smoothing)
   new_optimizer_state, new_params, new_model_state, loss, grad_norm = outputs

   # Log loss, grad_norm.
   if global_step % 100 == 0 and workload.metrics_logger is not None:
     workload.metrics_logger.append_scalar_metrics(
         {
-            'loss': loss[0],
-            'grad_norm': grad_norm[0],
+            'loss': loss.item(),
+            'grad_norm': grad_norm.item(),
         }, global_step)
   return (new_optimizer_state, opt_update_fn), new_params, new_model_state
@@ -215,6 +245,8 @@ def get_batch_size(workload_name):
     return 128
   elif workload_name == 'mnist':
     return 16
+  elif workload_name == 'cifar':
+    return 128
   else:
     raise ValueError(f'Unsupported workload name: {workload_name}.')
diff --git a/submission_runner.py b/submission_runner.py
index 551173bf5..d40b37bf4 100644
--- a/submission_runner.py
+++ b/submission_runner.py
@@ -377,8 +377,8 @@ def train_once(
     train_state['is_time_remaining'] = (
         train_state['accumulated_submission_time'] < max_allowed_runtime_sec)
     # Check if submission is eligible for an untimed eval.
-    if ((train_step_end_time - train_state['last_eval_time']) >=
-        workload.eval_period_time_sec or train_state['training_complete']):
+    if ((train_step_end_time - train_state['last_eval_time'])
+        >= workload.eval_period_time_sec or train_state['training_complete']):
       with profiler.profile('Evaluation'):
         del batch
         _reset_cuda_mem()
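
Both baselines now wire up their jitted step the same way: workload and opt_update_fn are static, the donated state and parameter buffers plus the batch get explicit shardings, and every output comes back replicated, which is why loss and grad_norm are global scalars logged with .item() instead of being indexed with [0]. A reduced sketch of that wiring (the least-squares train_step, the optax.sgd optimizer, and the array shapes are placeholders, not the baselines' real loss or update logic; the batch size is assumed divisible by jax.device_count()):

import jax
import jax.numpy as jnp
import optax
from jax.sharding import NamedSharding, PartitionSpec as P

from algorithmic_efficiency import sharding_utils


def train_step(opt_update_fn, optimizer_state, params, batch):
  # Placeholder loss: mean squared error of a linear model.
  def _loss_fn(p):
    preds = batch["inputs"] @ p["w"]
    return jnp.mean((preds - batch["targets"])**2)

  grad = jax.grad(_loss_fn)(params)
  updates, new_optimizer_state = opt_update_fn(grad, optimizer_state, params)
  return new_optimizer_state, optax.apply_updates(params, updates)


mesh = sharding_utils.get_mesh()
replicated = NamedSharding(mesh, P())  # parameters, optimizer state, outputs
sharded = NamedSharding(mesh, P("batch"))  # data, split along the batch axis

jitted_train_step = jax.jit(
    train_step,
    static_argnums=(0,),  # opt_update_fn is a plain Python callable
    donate_argnums=(1, 2),  # reuse the optimizer-state and parameter buffers
    in_shardings=(replicated, replicated, sharded),
    out_shardings=(replicated, replicated))

params = {"w": jnp.zeros((8, 1))}
batch = {"inputs": jnp.ones((32, 8)), "targets": jnp.zeros((32, 1))}
opt = optax.sgd(1e-2)
opt_state = opt.init(params)
opt_state, params = jitted_train_step(opt.update, opt_state, params, batch)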