Skip to content

Commit

Permalink
CIFAR workload sharding
Browse files (browse the repository at this point in the history)
  • Loading branch information
rka97 committed Dec 9, 2024
1 parent ef6af03 commit e6037d6
Show file tree
Hide file tree
Showing 2 changed files with 87 additions and 60 deletions.
Original file line number | Diff line number | Diff line change
Expand Up @@ -8,7 +8,6 @@
import functools
from typing import Dict, Iterator, Tuple

from flax import jax_utils
import jax
import tensorflow as tf
import tensorflow_datasets as tfds
Expand Down Expand Up @@ -171,5 +170,6 @@ def create_input_iter(
functools.partial(
shard_and_maybe_pad_np, global_batch_size=global_batch_size),
ds)
it = jax_utils.prefetch_to_device(it, 2)
# FIXME(rka97): Figure out how to do prefetching+sharding.
# it = jax_utils.prefetch_to_device(it, 2)
return it
143 changes: 85 additions & 58 deletions algorithmic_efficiency/workloads/cifar/cifar_jax/workload.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -3,7 +3,6 @@
import functools
from typing import Any, Dict, Iterator, Optional, Tuple

from flax import jax_utils
from flax import linen as nn
import jax
from jax import lax
Expand All @@ -12,6 +11,7 @@
import tensorflow_datasets as tfds

from algorithmic_efficiency import param_utils
from algorithmic_efficiency import sharding_utils
from algorithmic_efficiency import spec
from algorithmic_efficiency.workloads.cifar.cifar_jax import models
from algorithmic_efficiency.workloads.cifar.cifar_jax.input_pipeline import \
Expand All @@ -28,15 +28,16 @@ def _build_cifar_dataset(
data_dir: str,
batch_size: int,
cache: Optional[bool] = None,
repeat_final_dataset: Optional[bool] = None
repeat_final_dataset: Optional[bool] = None,
) -> Iterator[Dict[str, spec.Tensor]]:
ds_builder = tfds.builder('cifar10:3.0.2', data_dir=data_dir)
train = split == 'train'
data_dir = data_dir + "/cifar10"
ds_builder = tfds.builder("cifar10:3.0.2", data_dir=data_dir)
train = split == "train"
assert self.num_train_examples + self.num_validation_examples == 50000
if split in ['train', 'eval_train']:
split = f'train[:{self.num_train_examples}]'
elif split == 'validation':
split = f'train[{self.num_train_examples}:]'
if split in ["train", "eval_train"]:
split = f"train[:{self.num_train_examples}]"
elif split == "validation":
split = f"train[{self.num_train_examples}:]"
ds = create_input_iter(
split,
ds_builder,
Expand All @@ -48,7 +49,8 @@ def _build_cifar_dataset(
self.padding_size,
train=train,
cache=not train if cache is None else cache,
repeat_final_dataset=repeat_final_dataset)
repeat_final_dataset=repeat_final_dataset,
)
return ds

def _build_input_queue(
Expand All @@ -59,7 +61,8 @@ def _build_input_queue(
global_batch_size: int,
cache: Optional[bool] = None,
repeat_final_dataset: Optional[bool] = None,
num_batches: Optional[int] = None) -> Iterator[Dict[str, spec.Tensor]]:
num_batches: Optional[int] = None,
) -> Iterator[Dict[str, spec.Tensor]]:
del num_batches
return self._build_cifar_dataset(data_rng,
split,
Expand All @@ -74,34 +77,35 @@ def sync_batch_stats(
# An axis_name is passed to pmap which can then be used by pmean.
# In this case each device has its own version of the batch statistics
# and we average them.
avg_fn = jax.pmap(lambda x: lax.pmean(x, 'x'), 'x')
avg_fn = jax.pmap(lambda x: lax.pmean(x, "x"), "x")
new_model_state = model_state.copy(
{'batch_stats': avg_fn(model_state['batch_stats'])})
{"batch_stats": avg_fn(model_state["batch_stats"])})
return new_model_state

def init_model_fn(
self,
rng: spec.RandomState,
dropout_rate: Optional[float] = None,
aux_dropout_rate: Optional[float] = None) -> spec.ModelInitState:
aux_dropout_rate: Optional[float] = None,
) -> spec.ModelInitState:
"""Dropout is unused."""
del dropout_rate
del aux_dropout_rate
model_cls = getattr(models, 'ResNet18')
model_cls = getattr(models, "ResNet18")
model = model_cls(num_classes=self._num_classes, dtype=jnp.float32)
self._model = model
input_shape = (1, 32, 32, 3)
variables = jax.jit(model.init)({'params': rng},
variables = jax.jit(model.init)({"params": rng},
jnp.ones(input_shape, model.dtype))
model_state, params = variables.pop('params')
model_state, params = variables.pop("params")
self._param_shapes = param_utils.jax_param_shapes(params)
self._param_types = param_utils.jax_param_types(self._param_shapes)
model_state = jax_utils.replicate(model_state)
params = jax_utils.replicate(params)
# model_state = jax_utils.replicate(model_state)
# params = jax_utils.replicate(params)
return params, model_state

def is_output_params(self, param_key: spec.ParameterKey) -> bool:
return param_key == 'Dense_0'
return param_key == "Dense_0"

def model_fn(
self,
Expand All @@ -110,23 +114,26 @@ def model_fn(
model_state: spec.ModelAuxiliaryState,
mode: spec.ForwardPassMode,
rng: spec.RandomState,
update_batch_norm: bool) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]:
update_batch_norm: bool,
) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]:
del mode
del rng
variables = {'params': params, **model_state}
variables = {"params": params, **model_state}
if update_batch_norm:
logits, new_model_state = self._model.apply(
variables,
augmented_and_preprocessed_input_batch['inputs'],
augmented_and_preprocessed_input_batch["inputs"],
update_batch_norm=update_batch_norm,
mutable=['batch_stats'])
mutable=["batch_stats"],
)
return logits, new_model_state
else:
logits = self._model.apply(
variables,
augmented_and_preprocessed_input_batch['inputs'],
augmented_and_preprocessed_input_batch["inputs"],
update_batch_norm=update_batch_norm,
mutable=False)
mutable=False,
)
return logits, model_state

# Does NOT apply regularization, which is left to the submitter to do in
Expand All @@ -136,13 +143,15 @@ def loss_fn(
label_batch: spec.Tensor, # Dense or one-hot labels.
logits_batch: spec.Tensor,
mask_batch: Optional[spec.Tensor] = None,
label_smoothing: float = 0.0) -> Dict[str, spec.Tensor]: # differentiable
label_smoothing: float = 0.0,
) -> Dict[str, spec.Tensor]: # differentiable
"""Evaluate the (masked) loss function at (label_batch, logits_batch).
Return {'summed': scalar summed loss, 'n_valid_examples': scalar number of
valid examples in batch, 'per_example': 1-d array of per-example losses}
(not synced across devices).
"""
Return {'summed': scalar summed loss,
'n_valid_examples': scalar number of
valid examples in batch, 'per_example': 1-d array of per-example losses}
(not synced across devices).
"""
one_hot_targets = jax.nn.one_hot(label_batch, self._num_classes)
smoothed_targets = optax.smooth_labels(one_hot_targets, label_smoothing)
per_example_losses = -jnp.sum(
Expand All @@ -155,51 +164,69 @@ def loss_fn(
n_valid_examples = len(per_example_losses)
summed_loss = per_example_losses.sum()
return {
'summed': summed_loss,
'n_valid_examples': n_valid_examples,
'per_example': per_example_losses,
"summed": summed_loss,
"n_valid_examples": n_valid_examples,
"per_example": per_example_losses,
}

def _compute_metrics(self,
logits: spec.Tensor,
labels: spec.Tensor,
weights: spec.Tensor) -> Dict[str, spec.Tensor]:
summed_loss = self.loss_fn(labels, logits, weights)['summed']
summed_loss = self.loss_fn(labels, logits, weights)["summed"]
# Number of correct predictions.
accuracy = jnp.sum((jnp.argmax(logits, -1) == labels) * weights)
metrics = {
'loss': summed_loss,
'accuracy': accuracy,
}
metrics = lax.psum(metrics, axis_name='batch')
return metrics
return jnp.array(summed_loss), jnp.array(accuracy)

@functools.partial(
jax.pmap,
axis_name='batch',
in_axes=(None, 0, 0, 0, None),
static_broadcasted_argnums=(0,))
def _eval_model(
self,
params: spec.ParameterContainer,
batch: Dict[str, spec.Tensor],
model_state: spec.ModelAuxiliaryState,
rng: spec.RandomState) -> Dict[spec.Tensor, spec.ModelAuxiliaryState]:
rng: spec.RandomState,
) -> Dict[spec.Tensor, spec.ModelAuxiliaryState]:
"""Return the mean accuracy and loss as a dict."""
logits, _ = self.model_fn(
params,
batch,
model_state,
spec.ForwardPassMode.EVAL,
rng,
update_batch_norm=False)
weights = batch.get('weights')
if weights is None:
weights = jnp.ones(len(logits))
return self._compute_metrics(logits, batch['targets'], weights)

@functools.partial(
jax.jit,
in_shardings=(
sharding_utils.get_replicated_sharding(), # params
sharding_utils.get_naive_sharding_spec(), # batch
sharding_utils.get_replicated_sharding(), # model_state
sharding_utils.get_naive_sharding_spec(), # rng
),
)
def _per_device_eval_model(
params: spec.ParameterContainer,
batch: Dict[str, spec.Tensor],
model_state: spec.ModelAuxiliaryState,
rng: spec.RandomState,
) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]:
logits, _ = self.model_fn(
params,
batch,
model_state,
spec.ForwardPassMode.EVAL,
rng,
update_batch_norm=False,
)
weights = batch.get("weights")
if weights is None:
weights = jnp.ones(len(logits))
return self._compute_metrics(logits, batch["targets"], weights)

losses, accuracies = _per_device_eval_model(params, batch, model_state, rng)
metrics = {
"loss":
jnp.mean(losses, axis=0) if losses.ndim > 0 else losses,
"accuracy":
(jnp.mean(accuracies, axis=0) if accuracies.ndim > 0 else accuracies
),
}
return metrics

def _normalize_eval_metrics(
self, num_examples: int, total_metrics: Dict[str,
Any]) -> Dict[str, float]:
"""Normalize eval metrics."""
return jax.tree_map(lambda x: float(x[0] / num_examples), total_metrics)
return jax.tree_map(lambda x: x / num_examples, total_metrics)

0 comments on commit e6037d6

Please sign in to comment.