From e6037d6ef49f41abb0d74c44545a40a0f2d8c109 Mon Sep 17 00:00:00 2001 From: rka97 Date: Mon, 9 Dec 2024 01:58:03 -0500 Subject: [PATCH] CIFAR workload sharding --- .../cifar/cifar_jax/input_pipeline.py | 4 +- .../workloads/cifar/cifar_jax/workload.py | 143 +++++++++++------- 2 files changed, 87 insertions(+), 60 deletions(-) diff --git a/algorithmic_efficiency/workloads/cifar/cifar_jax/input_pipeline.py b/algorithmic_efficiency/workloads/cifar/cifar_jax/input_pipeline.py index 3e6a68844..1868dde6e 100644 --- a/algorithmic_efficiency/workloads/cifar/cifar_jax/input_pipeline.py +++ b/algorithmic_efficiency/workloads/cifar/cifar_jax/input_pipeline.py @@ -8,7 +8,6 @@ import functools from typing import Dict, Iterator, Tuple -from flax import jax_utils import jax import tensorflow as tf import tensorflow_datasets as tfds @@ -171,5 +170,6 @@ def create_input_iter( functools.partial( shard_and_maybe_pad_np, global_batch_size=global_batch_size), ds) - it = jax_utils.prefetch_to_device(it, 2) + # FIXME(rka97): Figure out how to do prefetching+sharding. + # it = jax_utils.prefetch_to_device(it, 2) return it diff --git a/algorithmic_efficiency/workloads/cifar/cifar_jax/workload.py b/algorithmic_efficiency/workloads/cifar/cifar_jax/workload.py index b019d1cee..cfafd1145 100644 --- a/algorithmic_efficiency/workloads/cifar/cifar_jax/workload.py +++ b/algorithmic_efficiency/workloads/cifar/cifar_jax/workload.py @@ -3,7 +3,6 @@ import functools from typing import Any, Dict, Iterator, Optional, Tuple -from flax import jax_utils from flax import linen as nn import jax from jax import lax @@ -12,6 +11,7 @@ import tensorflow_datasets as tfds from algorithmic_efficiency import param_utils +from algorithmic_efficiency import sharding_utils from algorithmic_efficiency import spec from algorithmic_efficiency.workloads.cifar.cifar_jax import models from algorithmic_efficiency.workloads.cifar.cifar_jax.input_pipeline import \ @@ -28,15 +28,16 @@ def _build_cifar_dataset( data_dir: str, batch_size: int, cache: Optional[bool] = None, - repeat_final_dataset: Optional[bool] = None + repeat_final_dataset: Optional[bool] = None, ) -> Iterator[Dict[str, spec.Tensor]]: - ds_builder = tfds.builder('cifar10:3.0.2', data_dir=data_dir) - train = split == 'train' + data_dir = data_dir + "/cifar10" + ds_builder = tfds.builder("cifar10:3.0.2", data_dir=data_dir) + train = split == "train" assert self.num_train_examples + self.num_validation_examples == 50000 - if split in ['train', 'eval_train']: - split = f'train[:{self.num_train_examples}]' - elif split == 'validation': - split = f'train[{self.num_train_examples}:]' + if split in ["train", "eval_train"]: + split = f"train[:{self.num_train_examples}]" + elif split == "validation": + split = f"train[{self.num_train_examples}:]" ds = create_input_iter( split, ds_builder, @@ -48,7 +49,8 @@ def _build_cifar_dataset( self.padding_size, train=train, cache=not train if cache is None else cache, - repeat_final_dataset=repeat_final_dataset) + repeat_final_dataset=repeat_final_dataset, + ) return ds def _build_input_queue( @@ -59,7 +61,8 @@ def _build_input_queue( global_batch_size: int, cache: Optional[bool] = None, repeat_final_dataset: Optional[bool] = None, - num_batches: Optional[int] = None) -> Iterator[Dict[str, spec.Tensor]]: + num_batches: Optional[int] = None, + ) -> Iterator[Dict[str, spec.Tensor]]: del num_batches return self._build_cifar_dataset(data_rng, split, @@ -74,34 +77,35 @@ def sync_batch_stats( # An axis_name is passed to pmap which can then be used by pmean. # In this case each device has its own version of the batch statistics # and we average them. - avg_fn = jax.pmap(lambda x: lax.pmean(x, 'x'), 'x') + avg_fn = jax.pmap(lambda x: lax.pmean(x, "x"), "x") new_model_state = model_state.copy( - {'batch_stats': avg_fn(model_state['batch_stats'])}) + {"batch_stats": avg_fn(model_state["batch_stats"])}) return new_model_state def init_model_fn( self, rng: spec.RandomState, dropout_rate: Optional[float] = None, - aux_dropout_rate: Optional[float] = None) -> spec.ModelInitState: + aux_dropout_rate: Optional[float] = None, + ) -> spec.ModelInitState: """Dropout is unused.""" del dropout_rate del aux_dropout_rate - model_cls = getattr(models, 'ResNet18') + model_cls = getattr(models, "ResNet18") model = model_cls(num_classes=self._num_classes, dtype=jnp.float32) self._model = model input_shape = (1, 32, 32, 3) - variables = jax.jit(model.init)({'params': rng}, + variables = jax.jit(model.init)({"params": rng}, jnp.ones(input_shape, model.dtype)) - model_state, params = variables.pop('params') + model_state, params = variables.pop("params") self._param_shapes = param_utils.jax_param_shapes(params) self._param_types = param_utils.jax_param_types(self._param_shapes) - model_state = jax_utils.replicate(model_state) - params = jax_utils.replicate(params) + # model_state = jax_utils.replicate(model_state) + # params = jax_utils.replicate(params) return params, model_state def is_output_params(self, param_key: spec.ParameterKey) -> bool: - return param_key == 'Dense_0' + return param_key == "Dense_0" def model_fn( self, @@ -110,23 +114,26 @@ def model_fn( model_state: spec.ModelAuxiliaryState, mode: spec.ForwardPassMode, rng: spec.RandomState, - update_batch_norm: bool) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]: + update_batch_norm: bool, + ) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]: del mode del rng - variables = {'params': params, **model_state} + variables = {"params": params, **model_state} if update_batch_norm: logits, new_model_state = self._model.apply( variables, - augmented_and_preprocessed_input_batch['inputs'], + augmented_and_preprocessed_input_batch["inputs"], update_batch_norm=update_batch_norm, - mutable=['batch_stats']) + mutable=["batch_stats"], + ) return logits, new_model_state else: logits = self._model.apply( variables, - augmented_and_preprocessed_input_batch['inputs'], + augmented_and_preprocessed_input_batch["inputs"], update_batch_norm=update_batch_norm, - mutable=False) + mutable=False, + ) return logits, model_state # Does NOT apply regularization, which is left to the submitter to do in @@ -136,13 +143,15 @@ def loss_fn( label_batch: spec.Tensor, # Dense or one-hot labels. logits_batch: spec.Tensor, mask_batch: Optional[spec.Tensor] = None, - label_smoothing: float = 0.0) -> Dict[str, spec.Tensor]: # differentiable + label_smoothing: float = 0.0, + ) -> Dict[str, spec.Tensor]: # differentiable """Evaluate the (masked) loss function at (label_batch, logits_batch). - Return {'summed': scalar summed loss, 'n_valid_examples': scalar number of - valid examples in batch, 'per_example': 1-d array of per-example losses} - (not synced across devices). - """ + Return {'summed': scalar summed loss, + 'n_valid_examples': scalar number of + valid examples in batch, 'per_example': 1-d array of per-example losses} + (not synced across devices). + """ one_hot_targets = jax.nn.one_hot(label_batch, self._num_classes) smoothed_targets = optax.smooth_labels(one_hot_targets, label_smoothing) per_example_losses = -jnp.sum( @@ -155,51 +164,69 @@ def loss_fn( n_valid_examples = len(per_example_losses) summed_loss = per_example_losses.sum() return { - 'summed': summed_loss, - 'n_valid_examples': n_valid_examples, - 'per_example': per_example_losses, + "summed": summed_loss, + "n_valid_examples": n_valid_examples, + "per_example": per_example_losses, } def _compute_metrics(self, logits: spec.Tensor, labels: spec.Tensor, weights: spec.Tensor) -> Dict[str, spec.Tensor]: - summed_loss = self.loss_fn(labels, logits, weights)['summed'] + summed_loss = self.loss_fn(labels, logits, weights)["summed"] # Number of correct predictions. accuracy = jnp.sum((jnp.argmax(logits, -1) == labels) * weights) - metrics = { - 'loss': summed_loss, - 'accuracy': accuracy, - } - metrics = lax.psum(metrics, axis_name='batch') - return metrics + return jnp.array(summed_loss), jnp.array(accuracy) - @functools.partial( - jax.pmap, - axis_name='batch', - in_axes=(None, 0, 0, 0, None), - static_broadcasted_argnums=(0,)) def _eval_model( self, params: spec.ParameterContainer, batch: Dict[str, spec.Tensor], model_state: spec.ModelAuxiliaryState, - rng: spec.RandomState) -> Dict[spec.Tensor, spec.ModelAuxiliaryState]: + rng: spec.RandomState, + ) -> Dict[spec.Tensor, spec.ModelAuxiliaryState]: """Return the mean accuracy and loss as a dict.""" - logits, _ = self.model_fn( - params, - batch, - model_state, - spec.ForwardPassMode.EVAL, - rng, - update_batch_norm=False) - weights = batch.get('weights') - if weights is None: - weights = jnp.ones(len(logits)) - return self._compute_metrics(logits, batch['targets'], weights) + + @functools.partial( + jax.jit, + in_shardings=( + sharding_utils.get_replicated_sharding(), # params + sharding_utils.get_naive_sharding_spec(), # batch + sharding_utils.get_replicated_sharding(), # model_state + sharding_utils.get_naive_sharding_spec(), # rng + ), + ) + def _per_device_eval_model( + params: spec.ParameterContainer, + batch: Dict[str, spec.Tensor], + model_state: spec.ModelAuxiliaryState, + rng: spec.RandomState, + ) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]: + logits, _ = self.model_fn( + params, + batch, + model_state, + spec.ForwardPassMode.EVAL, + rng, + update_batch_norm=False, + ) + weights = batch.get("weights") + if weights is None: + weights = jnp.ones(len(logits)) + return self._compute_metrics(logits, batch["targets"], weights) + + losses, accuracies = _per_device_eval_model(params, batch, model_state, rng) + metrics = { + "loss": + jnp.mean(losses, axis=0) if losses.ndim > 0 else losses, + "accuracy": + (jnp.mean(accuracies, axis=0) if accuracies.ndim > 0 else accuracies + ), + } + return metrics def _normalize_eval_metrics( self, num_examples: int, total_metrics: Dict[str, Any]) -> Dict[str, float]: """Normalize eval metrics.""" - return jax.tree_map(lambda x: float(x[0] / num_examples), total_metrics) + return jax.tree_map(lambda x: x / num_examples, total_metrics)