[WIP] JAX.JIT Switch and Sharding #822

Open
wants to merge 4 commits into base: dev
3 changes: 0 additions & 3 deletions algorithmic_efficiency/checkpoint_utils.py
@@ -193,10 +193,7 @@ def save_checkpoint(framework: str,
train_state, eval_results, global_step, preemption_count).
"""
if framework == 'jax':
model_params = jax.device_get(jax_utils.unreplicate(model_params))
opt_state, _ = optimizer_state
opt_state = jax.device_get(jax_utils.unreplicate(opt_state))
model_state = jax.device_get(jax_utils.unreplicate(model_state))
else:
if isinstance(
model_params,
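With the pmap-to-jit switch, parameters and optimizer state are ordinary (possibly sharded) jax.Arrays rather than per-device replicated stacks, so the unreplicate calls removed above have nothing left to undo. A minimal sketch of the remaining host transfer under that assumption (the toy pytree is illustrative, not code from this PR):

import jax
import jax.numpy as jnp

# Toy stand-in for model_params / opt_state; any pytree of jax.Arrays works.
params = {"w": jnp.ones((4, 4)), "b": jnp.zeros((4,))}

# device_get walks the pytree and gathers each (possibly sharded) array back
# to host numpy, so no per-device unreplicate step is required.
host_params = jax.device_get(params)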
5 changes: 1 addition & 4 deletions algorithmic_efficiency/data_utils.py
@@ -60,10 +60,7 @@ def _prepare(x):
if remainder_size != 0 or pad_to_global_batch_size:
x = pad(x, pad_size, padding_value=padding_value)

# Reshape (global_batch_size, ...) to
# (local_device_count, per_device_batch_size, ...).
# Assumes that `global_batch_size % local_device_count == 0`.
return x.reshape((local_device_count, -1, *x.shape[1:]))
return x

return jax.tree_map(_prepare, batch)

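For context, the reshape removed above produced the (local_device_count, per_device_batch_size, ...) layout that pmap expects. Under the jit/sharding approach the global batch is kept intact and placed with a batch-dimension sharding instead; a short sketch under that assumption (shapes are illustrative):

import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec

mesh = Mesh(jax.devices(), ("batch",))
batch_sharding = NamedSharding(mesh, PartitionSpec("batch"))

# Keep the global batch shape; device_put splits axis 0 across the mesh,
# assuming the device count divides the global batch size evenly.
x = jnp.ones((128, 32, 32, 3))
x = jax.device_put(x, batch_sharding)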
61 changes: 61 additions & 0 deletions algorithmic_efficiency/sharding_utils.py
@@ -0,0 +1,61 @@
"""Utilities for dealing with sharding in JAX."""

import jax
from jax.sharding import Mesh, NamedSharding, PartitionSpec


def get_mesh() -> jax.sharding.Mesh:
"""Creates a mesh from all available GPUs. Here, we simply create a one-dimensional mesh."""
return jax.sharding.Mesh(jax.devices(), ("batch",))


def get_replicated_sharding(mesh=None):
"""Returns a sharding spec that replicates data across all devices."""
if mesh is None:
mesh = get_mesh()
return NamedSharding(mesh, PartitionSpec())


def get_naive_sharding_spec(mesh=None):
"""Returns a sharding spec that shards data along the first axis."""
if mesh is None:
mesh = get_mesh()
return NamedSharding(mesh, PartitionSpec("batch"))


def get_naive_sharding(x, mesh=None):
"""Given a 1D mesh and a tensor, try to shard along the appropriate axis."""
if mesh is None:
mesh = get_mesh()
grid_size = mesh.shape["batch"]
if x.shape[0] % grid_size == 0:
return NamedSharding(mesh, PartitionSpec("batch"))
else:
return NamedSharding(mesh, PartitionSpec())


def shard_params(params, mesh=None):
"""Shards a parameter tree across all devices with naive sharding (see get_naive_sharding)."""
if mesh is None:
mesh = get_mesh()
return jax.tree_util.tree_map(
lambda x: jax.device_put(x, get_naive_sharding(x, mesh)), params)


def get_sharding_tree(params, mesh=None):
"""Returns a sharding tree for a parameter tree."""
return jax.tree_util.tree_map(lambda x: get_naive_sharding(x, mesh), params)


def get_empty_sharding(mesh=None):
"""Returns a sharding spec that replicates data across all devices."""
if mesh is None:
mesh = get_mesh()
return NamedSharding(mesh, PartitionSpec())


def disp_shard_info(x: jax.Array):
"""Displays shard info of a jax array."""
for shard in x.addressable_shards:
print(f"shard.device: {shard.device}, index: {shard.index}, replica_id:"
f" {shard.replica_id}.\n")
@@ -8,7 +8,6 @@
import functools
from typing import Dict, Iterator, Tuple

from flax import jax_utils
import jax
import tensorflow as tf
import tensorflow_datasets as tfds
@@ -171,5 +170,6 @@ def create_input_iter(
functools.partial(
shard_and_maybe_pad_np, global_batch_size=global_batch_size),
ds)
it = jax_utils.prefetch_to_device(it, 2)
# FIXME(rka97): Figure out how to do prefetching+sharding.
# it = jax_utils.prefetch_to_device(it, 2)
return it
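One possible way to approach the FIXME above, sketched here rather than taken from the PR: keep a small queue of batches that have already been device_put with the batch sharding, so host-to-device copies overlap with the current step, much as jax_utils.prefetch_to_device did for the pmap layout. The helper name sharded_prefetch is hypothetical.

import collections
import itertools

import jax

from algorithmic_efficiency import sharding_utils


def sharded_prefetch(iterator, size=2):
  """Yields batches that were device_put with a batch sharding `size` steps early."""
  sharding = sharding_utils.get_naive_sharding_spec()
  queue = collections.deque()

  def enqueue(n):
    # Assumes the device count divides every leaf's leading axis.
    for batch in itertools.islice(iterator, n):
      queue.append(
          jax.tree_util.tree_map(lambda x: jax.device_put(x, sharding), batch))

  enqueue(size)
  while queue:
    yield queue.popleft()
    enqueue(1)

Something like it = sharded_prefetch(it, 2) could then take the place of the commented-out call.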
143 changes: 85 additions & 58 deletions algorithmic_efficiency/workloads/cifar/cifar_jax/workload.py
@@ -3,7 +3,6 @@
import functools
from typing import Any, Dict, Iterator, Optional, Tuple

from flax import jax_utils
from flax import linen as nn
import jax
from jax import lax
@@ -12,6 +11,7 @@
import tensorflow_datasets as tfds

from algorithmic_efficiency import param_utils
from algorithmic_efficiency import sharding_utils
from algorithmic_efficiency import spec
from algorithmic_efficiency.workloads.cifar.cifar_jax import models
from algorithmic_efficiency.workloads.cifar.cifar_jax.input_pipeline import \
@@ -28,15 +28,16 @@ def _build_cifar_dataset(
data_dir: str,
batch_size: int,
cache: Optional[bool] = None,
repeat_final_dataset: Optional[bool] = None
repeat_final_dataset: Optional[bool] = None,
) -> Iterator[Dict[str, spec.Tensor]]:
ds_builder = tfds.builder('cifar10:3.0.2', data_dir=data_dir)
train = split == 'train'
data_dir = data_dir + "/cifar10"
ds_builder = tfds.builder("cifar10:3.0.2", data_dir=data_dir)
train = split == "train"
assert self.num_train_examples + self.num_validation_examples == 50000
if split in ['train', 'eval_train']:
split = f'train[:{self.num_train_examples}]'
elif split == 'validation':
split = f'train[{self.num_train_examples}:]'
if split in ["train", "eval_train"]:
split = f"train[:{self.num_train_examples}]"
elif split == "validation":
split = f"train[{self.num_train_examples}:]"
ds = create_input_iter(
split,
ds_builder,
@@ -48,7 +49,8 @@ def _build_input_queue(
self.padding_size,
train=train,
cache=not train if cache is None else cache,
repeat_final_dataset=repeat_final_dataset)
repeat_final_dataset=repeat_final_dataset,
)
return ds

def _build_input_queue(
@@ -59,7 +61,8 @@ def _build_input_queue(
global_batch_size: int,
cache: Optional[bool] = None,
repeat_final_dataset: Optional[bool] = None,
num_batches: Optional[int] = None) -> Iterator[Dict[str, spec.Tensor]]:
num_batches: Optional[int] = None,
) -> Iterator[Dict[str, spec.Tensor]]:
del num_batches
return self._build_cifar_dataset(data_rng,
split,
@@ -74,34 +77,35 @@ def sync_batch_stats(
# An axis_name is passed to pmap which can then be used by pmean.
# In this case each device has its own version of the batch statistics
# and we average them.
avg_fn = jax.pmap(lambda x: lax.pmean(x, 'x'), 'x')
avg_fn = jax.pmap(lambda x: lax.pmean(x, "x"), "x")
new_model_state = model_state.copy(
{'batch_stats': avg_fn(model_state['batch_stats'])})
{"batch_stats": avg_fn(model_state["batch_stats"])})
return new_model_state

def init_model_fn(
self,
rng: spec.RandomState,
dropout_rate: Optional[float] = None,
aux_dropout_rate: Optional[float] = None) -> spec.ModelInitState:
aux_dropout_rate: Optional[float] = None,
) -> spec.ModelInitState:
"""Dropout is unused."""
del dropout_rate
del aux_dropout_rate
model_cls = getattr(models, 'ResNet18')
model_cls = getattr(models, "ResNet18")
model = model_cls(num_classes=self._num_classes, dtype=jnp.float32)
self._model = model
input_shape = (1, 32, 32, 3)
variables = jax.jit(model.init)({'params': rng},
variables = jax.jit(model.init)({"params": rng},
jnp.ones(input_shape, model.dtype))
model_state, params = variables.pop('params')
model_state, params = variables.pop("params")
self._param_shapes = param_utils.jax_param_shapes(params)
self._param_types = param_utils.jax_param_types(self._param_shapes)
model_state = jax_utils.replicate(model_state)
params = jax_utils.replicate(params)
# model_state = jax_utils.replicate(model_state)
# params = jax_utils.replicate(params)
return params, model_state
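If the replication that the commented-out jax_utils.replicate calls used to perform is still wanted on the jit path, one way to express it with the new utilities (an assumption about intent, not something this diff does) is:

# Hypothetical sketch: place params and model_state as single replicated
# jax.Arrays; device_put broadcasts one sharding over a whole pytree.
params = jax.device_put(params, sharding_utils.get_replicated_sharding())
model_state = jax.device_put(model_state, sharding_utils.get_replicated_sharding())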

def is_output_params(self, param_key: spec.ParameterKey) -> bool:
return param_key == 'Dense_0'
return param_key == "Dense_0"

def model_fn(
self,
@@ -110,23 +114,26 @@ def model_fn(
model_state: spec.ModelAuxiliaryState,
mode: spec.ForwardPassMode,
rng: spec.RandomState,
update_batch_norm: bool) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]:
update_batch_norm: bool,
) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]:
del mode
del rng
variables = {'params': params, **model_state}
variables = {"params": params, **model_state}
if update_batch_norm:
logits, new_model_state = self._model.apply(
variables,
augmented_and_preprocessed_input_batch['inputs'],
augmented_and_preprocessed_input_batch["inputs"],
update_batch_norm=update_batch_norm,
mutable=['batch_stats'])
mutable=["batch_stats"],
)
return logits, new_model_state
else:
logits = self._model.apply(
variables,
augmented_and_preprocessed_input_batch['inputs'],
augmented_and_preprocessed_input_batch["inputs"],
update_batch_norm=update_batch_norm,
mutable=False)
mutable=False,
)
return logits, model_state

# Does NOT apply regularization, which is left to the submitter to do in
@@ -136,13 +143,15 @@ def loss_fn(
label_batch: spec.Tensor, # Dense or one-hot labels.
logits_batch: spec.Tensor,
mask_batch: Optional[spec.Tensor] = None,
label_smoothing: float = 0.0) -> Dict[str, spec.Tensor]: # differentiable
label_smoothing: float = 0.0,
) -> Dict[str, spec.Tensor]: # differentiable
"""Evaluate the (masked) loss function at (label_batch, logits_batch).

Return {'summed': scalar summed loss, 'n_valid_examples': scalar number of
valid examples in batch, 'per_example': 1-d array of per-example losses}
(not synced across devices).
"""
Return {'summed': scalar summed loss,
'n_valid_examples': scalar number of
valid examples in batch, 'per_example': 1-d array of per-example losses}
(not synced across devices).
"""
one_hot_targets = jax.nn.one_hot(label_batch, self._num_classes)
smoothed_targets = optax.smooth_labels(one_hot_targets, label_smoothing)
per_example_losses = -jnp.sum(
@@ -155,51 +164,69 @@ def loss_fn(
n_valid_examples = len(per_example_losses)
summed_loss = per_example_losses.sum()
return {
'summed': summed_loss,
'n_valid_examples': n_valid_examples,
'per_example': per_example_losses,
"summed": summed_loss,
"n_valid_examples": n_valid_examples,
"per_example": per_example_losses,
}

def _compute_metrics(self,
logits: spec.Tensor,
labels: spec.Tensor,
weights: spec.Tensor) -> Dict[str, spec.Tensor]:
summed_loss = self.loss_fn(labels, logits, weights)['summed']
summed_loss = self.loss_fn(labels, logits, weights)["summed"]
# Number of correct predictions.
accuracy = jnp.sum((jnp.argmax(logits, -1) == labels) * weights)
metrics = {
'loss': summed_loss,
'accuracy': accuracy,
}
metrics = lax.psum(metrics, axis_name='batch')
return metrics
return jnp.array(summed_loss), jnp.array(accuracy)

@functools.partial(
jax.pmap,
axis_name='batch',
in_axes=(None, 0, 0, 0, None),
static_broadcasted_argnums=(0,))
def _eval_model(
self,
params: spec.ParameterContainer,
batch: Dict[str, spec.Tensor],
model_state: spec.ModelAuxiliaryState,
rng: spec.RandomState) -> Dict[spec.Tensor, spec.ModelAuxiliaryState]:
rng: spec.RandomState,
) -> Dict[spec.Tensor, spec.ModelAuxiliaryState]:
"""Return the mean accuracy and loss as a dict."""
logits, _ = self.model_fn(
params,
batch,
model_state,
spec.ForwardPassMode.EVAL,
rng,
update_batch_norm=False)
weights = batch.get('weights')
if weights is None:
weights = jnp.ones(len(logits))
return self._compute_metrics(logits, batch['targets'], weights)

@functools.partial(
jax.jit,
in_shardings=(
sharding_utils.get_replicated_sharding(), # params
sharding_utils.get_naive_sharding_spec(), # batch
sharding_utils.get_replicated_sharding(), # model_state
sharding_utils.get_naive_sharding_spec(), # rng
),
)
def _per_device_eval_model(
params: spec.ParameterContainer,
batch: Dict[str, spec.Tensor],
model_state: spec.ModelAuxiliaryState,
rng: spec.RandomState,
) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]:
logits, _ = self.model_fn(
params,
batch,
model_state,
spec.ForwardPassMode.EVAL,
rng,
update_batch_norm=False,
)
weights = batch.get("weights")
if weights is None:
weights = jnp.ones(len(logits))
return self._compute_metrics(logits, batch["targets"], weights)

losses, accuracies = _per_device_eval_model(params, batch, model_state, rng)
metrics = {
"loss":
jnp.mean(losses, axis=0) if losses.ndim > 0 else losses,
"accuracy":
(jnp.mean(accuracies, axis=0) if accuracies.ndim > 0 else accuracies
),
}
return metrics

def _normalize_eval_metrics(
self, num_examples: int, total_metrics: Dict[str,
Any]) -> Dict[str, float]:
"""Normalize eval metrics."""
return jax.tree_map(lambda x: float(x[0] / num_examples), total_metrics)
return jax.tree_map(lambda x: x / num_examples, total_metrics)
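A minimal standalone sketch of the pattern used in _per_device_eval_model above: jax.jit with in_shardings replicates the parameters and shards the batch across the mesh, taking over the role the pmap decorator played before. The toy model, names, and shapes here are illustrative, not code from this PR.

import functools

import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec

mesh = Mesh(jax.devices(), ("batch",))
replicated = NamedSharding(mesh, PartitionSpec())
batch_sharded = NamedSharding(mesh, PartitionSpec("batch"))


@functools.partial(jax.jit, in_shardings=(replicated, batch_sharded))
def eval_step(w, x):
  logits = x @ w  # toy linear "model"
  return jnp.sum(logits), x.shape[0]  # summed metric and example count


w = jax.device_put(jnp.ones((8, 4)), replicated)
x = jax.device_put(jnp.ones((16, 8)), batch_sharded)
summed, n = eval_step(w, x)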