
Commit

squashing
minor fixes, redoing how the experiment name works, setting preliminary eval_period_time_sec values, still need to set max runtime, set updated target-setting hparams

some fixes to Criteo PyTorch, doc touch-ups, minor fixes

removing initialized() in jax resnet imagenet

lint

batch mask key, fixes

Merge remote-tracking branch 'znado/naman-targets' into naman-targets

lint 2
znado committed Oct 10, 2022
1 parent fa7359b commit 7ba0c4c
Showing 44 changed files with 290 additions and 171 deletions.
7 changes: 6 additions & 1 deletion README.md
@@ -152,6 +152,7 @@ Docker is the easiest way to enable PyTorch/JAX GPU support on Linux since only
python3 submission_runner.py \
--framework=jax \
--workload=mnist \
--experiment_dir=/home/znado \
--submission_path=reference_submissions/mnist/mnist_jax/submission.py \
--tuning_search_space=reference_submissions/mnist/tuning_search_space.json
```
@@ -162,6 +163,7 @@ python3 submission_runner.py \
python3 submission_runner.py \
--framework=pytorch \
--workload=mnist \
--experiment_dir=/home/znado \
--submission_path=reference_submissions/mnist/mnist_pytorch/submission.py \
--tuning_search_space=reference_submissions/mnist/tuning_search_space.json
```
@@ -174,7 +176,10 @@ To do so, simply replace `python3` by
torchrun --standalone --nnodes=1 --nproc_per_node=N_GPUS
```

where `N_GPUS` is the number of available GPUs on the node.
where `N_GPUS` is the number of available GPUs on the node. To see output only from the first process, you can run the following, which redirects the output from processes 1-7 to a log file:
```bash
torchrun --redirects 1:0,2:0,3:0,4:0,5:0,6:0,7:0 --standalone --nnodes=1 --nproc_per_node=8
```

## Rules

6 changes: 5 additions & 1 deletion RULES.md
@@ -107,7 +107,11 @@ def model_fn(
###### Loss function

```python
def loss_fn(label_batch, logits_output_batch) -> 1d array of losses per example # differentiable
def loss_fn(
label_batch: Union[Tuple[Tensor, Tensor], Tensor],
logits_batch: Union[Tuple[Tensor, Tensor], Tensor],
mask_batch: Optional[Tensor] = None,
label_smoothing: float = 0.0) -> 1d array of losses per example # differentiable
```

- Unlike in the [Model Track](#model-track), we will specify the loss function name in order to let training algorithms depend on the loss function. It will be one of {**mean squared error**, **cross-entropy**, **CTC**, or **L1 reconstruction error**}.
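A minimal sketch (not part of the spec; `workload` is a hypothetical object exposing the `loss_fn` above) of how a submission might reduce the per-example losses to a scalar, normalizing by the number of unmasked examples when a mask is supplied:

```python
# Sketch only: reduce per-example losses to a mean, respecting an optional mask.
def mean_loss(workload, label_batch, logits_batch, mask_batch=None,
              label_smoothing=0.0):
  per_example_losses = workload.loss_fn(
      label_batch, logits_batch, mask_batch, label_smoothing)
  if mask_batch is not None:
    # Count only unmasked examples in the normalization.
    normalization = mask_batch.sum()
  else:
    normalization = label_batch.shape[0]
  return per_example_losses.sum() / normalization
```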
4 changes: 2 additions & 2 deletions algorithmic_efficiency/logger_utils.py
@@ -131,8 +131,8 @@ def _get_git_commit_hash() -> str:


def _get_git_branch() -> str:
return subprocess.check_output(['git', 'branch',
'--show-current']).decode('ascii').strip()
return subprocess.check_output(['git', 'rev-parse', '--abbrev-ref',
'HEAD']).decode('ascii').strip()


def _get_cpu_model_name() -> str:
11 changes: 8 additions & 3 deletions algorithmic_efficiency/workloads/cifar/cifar_jax/workload.py
@@ -177,21 +177,26 @@ def model_fn(
augmented_and_preprocessed_input_batch['inputs'],
update_batch_norm=update_batch_norm,
mutable=False)
return logits, None
return logits, model_state

# Does NOT apply regularization, which is left to the submitter to do in
# `update_params`.
def loss_fn(self,
label_batch: spec.Tensor,
logits_batch: spec.Tensor,
mask_batch: Optional[spec.Tensor] = None,
label_smoothing: float = 0.0) -> spec.Tensor: # differentiable
one_hot_targets = jax.nn.one_hot(label_batch, 10)
smoothed_targets = optax.smooth_labels(one_hot_targets, label_smoothing)
return -jnp.sum(smoothed_targets * nn.log_softmax(logits_batch), axis=-1)
losses = -jnp.sum(smoothed_targets * nn.log_softmax(logits_batch), axis=-1)
# mask_batch is assumed to be shape [batch].
if mask_batch is not None:
losses *= mask_batch
return losses

def _compute_metrics(self, logits, labels):
loss = jnp.sum(self.loss_fn(labels, logits))
# not accuracy, but nr. of correct predictions
# Number of correct predictions.
accuracy = jnp.sum(jnp.argmax(logits, -1) == labels)
metrics = {
'loss': loss,
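As the `# Does NOT apply regularization` comments above indicate, regularization is left to the submitter in `update_params`. A rough JAX-flavored sketch of that idea (hypothetical `workload` object and `params` pytree; not code from this commit):

```python
import jax
import jax.numpy as jnp

# Sketch only: add an L2 penalty on top of the unregularized per-example losses.
def l2_regularized_loss(workload, params, label_batch, logits_batch,
                        l2_coeff=1e-4):
  per_example_losses = workload.loss_fn(label_batch, logits_batch)
  l2_penalty = sum(
      jnp.sum(jnp.square(p)) for p in jax.tree_util.tree_leaves(params))
  return jnp.mean(per_example_losses) + l2_coeff * l2_penalty
```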
@@ -199,12 +199,17 @@ def output_activation_fn(self,
def loss_fn(self,
label_batch: spec.Tensor,
logits_batch: spec.Tensor,
mask_batch: Optional[spec.Tensor] = None,
label_smoothing: float = 0.0) -> spec.Tensor: # differentiable
return F.cross_entropy(
losses = F.cross_entropy(
logits_batch,
label_batch,
reduction='none',
label_smoothing=label_smoothing)
# mask_batch is assumed to be shape [batch].
if mask_batch is not None:
losses *= mask_batch
return losses

def _eval_metric(self, logits, labels):
"""Return the mean accuracy and loss as a dict."""
@@ -33,17 +33,11 @@ def loss_fn(
logits_batch: spec.Tensor,
mask_batch: Optional[spec.Tensor] = None,
label_smoothing: float = 0.0) -> spec.Tensor:
# TODO(znado): confirm that we do not want to tune label smoothing here.
del label_smoothing
per_example_losses = metrics.per_example_sigmoid_binary_cross_entropy(
logits=logits_batch, targets=label_batch)
if mask_batch is not None:
weighted_losses = per_example_losses * mask_batch
normalization = mask_batch.sum()
else:
weighted_losses = per_example_losses
normalization = label_batch.shape[0]
return jnp.sum(weighted_losses, axis=-1) / normalization
per_example_losses *= mask_batch
return per_example_losses

def init_model_fn(self, rng: spec.RandomState) -> spec.ModelInitState:
self._model = models.DlrmSmall(
@@ -41,12 +41,8 @@ def loss_fn(self,
per_example_losses = metrics.per_example_sigmoid_binary_cross_entropy(
logits=logits_batch, targets=label_batch)
if mask_batch is not None:
weighted_losses = per_example_losses * mask_batch
normalization = mask_batch.sum()
else:
weighted_losses = per_example_losses
normalization = label_batch.shape[0]
return torch.sum(weighted_losses, dim=-1) / normalization
per_example_losses *= mask_batch
return per_example_losses

def _eval_metric(self, logits: spec.Tensor,
targets: spec.Tensor) -> Dict[str, int]:
@@ -81,9 +77,13 @@ def model_fn(
model_state: spec.ModelAuxiliaryState,
mode: spec.ForwardPassMode,
rng: spec.RandomState,
dropout_rate: Optional[float],
aux_dropout_rate: Optional[float],
update_batch_norm: bool) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]:
del model_state
del rng
del dropout_rate
del aux_dropout_rate
del update_batch_norm

model = params
@@ -188,9 +188,11 @@ def _eval_batch(self, params, batch):
model_state=None,
mode=spec.ForwardPassMode.EVAL,
rng=None,
dropout_rate=None,
aux_dropout_rate=None,
update_batch_norm=False)
per_example_losses = metrics.per_example_sigmoid_binary_cross_entropy(
logits, batch['targets'])
batch_loss_numerator = torch.sum(per_example_losses)
batch_loss_denominator = torch.sum(batch['weights'])
batch_loss_numerator = torch.sum(per_example_losses).cpu().numpy()
batch_loss_denominator = torch.sum(batch['weights']).cpu().numpy()
return batch_loss_numerator, batch_loss_denominator
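The `(numerator, denominator)` pair returned by `_eval_batch` above is presumably summed over all evaluation batches before dividing; a small sketch of that aggregation (an assumption about the caller, which is not shown in this diff):

```python
# Sketch only: combine per-batch (numerator, denominator) pairs into one
# weighted mean loss over the full evaluation set.
def aggregate_eval_loss(batch_results):
  total_numerator = sum(numerator for numerator, _ in batch_results)
  total_denominator = sum(denominator for _, denominator in batch_results)
  return total_numerator / total_denominator
```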
4 changes: 2 additions & 2 deletions algorithmic_efficiency/workloads/criteo1tb/workload.py
@@ -28,7 +28,7 @@ def has_reached_goal(self, eval_result: float) -> bool:

@property
def target_value(self):
return 0.1255
return 0.12422498

@property
def loss_type(self):
@@ -64,7 +64,7 @@ def max_allowed_runtime_sec(self):

@property
def eval_period_time_sec(self):
return 10 * 60
return 24 * 60

def output_activation_fn(self,
logits_batch: spec.Tensor,
17 changes: 13 additions & 4 deletions algorithmic_efficiency/workloads/fastmri/fastmri_jax/workload.py
@@ -71,10 +71,19 @@ def output_activation_fn(self,

# Does NOT apply regularization, which is left to the submitter to do in
# `update_params`.
def loss_fn(self, label_batch: spec.Tensor,
outputs_batch: spec.Tensor) -> spec.Tensor: # differentiable
return jnp.abs(outputs_batch -
label_batch).mean(axis=tuple(range(1, outputs_batch.ndim)))
def loss_fn(self,
label_batch: spec.Tensor,
logits_batch: spec.Tensor,
mask_batch: Optional[spec.Tensor] = None,
label_smoothing: float = 0.0) -> spec.Tensor: # differentiable
del label_smoothing
losses = jnp.sum(
jnp.abs(logits_batch - label_batch),
axis=tuple(range(1, logits_batch.ndim)))
# mask_batch is assumed to be shape [batch].
if mask_batch is not None:
losses *= mask_batch
return losses

@functools.partial(
jax.pmap,
@@ -173,12 +173,15 @@ def output_activation_fn(self,
def loss_fn(self,
label_batch: spec.Tensor,
logits_batch: spec.Tensor,
mask_batch: spec.Tensor = None,
mask_batch: Optional[spec.Tensor] = None,
label_smoothing: float = 0.0) -> spec.Tensor: # differentiable
del mask_batch
del label_smoothing
return F.l1_loss(
logits_batch, label_batch, reduction='none').mean(dim=(1, 2))
losses = F.l1_loss(
logits_batch, label_batch, reduction='none').sum(dim=(1, 2))
# mask_batch is assumed to be shape [batch].
if mask_batch is not None:
losses *= mask_batch
return losses

def _eval_model(self, params, batch, rng):
"""Return the SSIM and loss as a dict."""
4 changes: 2 additions & 2 deletions algorithmic_efficiency/workloads/fastmri/workload.py
@@ -18,7 +18,7 @@ def has_reached_goal(self, eval_result: float) -> bool:

@property
def target_value(self):
return 0.70
return 0.735102235

@property
def loss_type(self):
@@ -62,7 +62,7 @@ def max_allowed_runtime_sec(self):

@property
def eval_period_time_sec(self):
return 6000 # 100 mins
return 80

@property
def param_shapes(self):
@@ -4,7 +4,7 @@
https://github.com/google/flax/blob/main/examples/imagenet/models.py
"""

from functools import partial
import functools
from typing import Any, Callable, Tuple

from flax import linen as nn
@@ -79,8 +79,8 @@ class ResNet(nn.Module):

@nn.compact
def __call__(self, x, update_batch_norm: bool = True):
conv = partial(nn.Conv, use_bias=False, dtype=self.dtype)
norm = partial(
conv = functools.partial(nn.Conv, use_bias=False, dtype=self.dtype)
norm = functools.partial(
nn.BatchNorm,
use_running_average=not update_batch_norm,
momentum=0.9,
@@ -112,16 +112,18 @@ def __call__(self, x, update_batch_norm: bool = True):
return x


ResNet18 = partial(ResNet, stage_sizes=(2, 2, 2, 2), block_cls=ResNetBlock)
ResNet34 = partial(ResNet, stage_sizes=(3, 4, 6, 3), block_cls=ResNetBlock)
ResNet50 = partial(
ResNet18 = functools.partial(
ResNet, stage_sizes=(2, 2, 2, 2), block_cls=ResNetBlock)
ResNet34 = functools.partial(
ResNet, stage_sizes=(3, 4, 6, 3), block_cls=ResNetBlock)
ResNet50 = functools.partial(
ResNet, stage_sizes=(3, 4, 6, 3), block_cls=BottleneckResNetBlock)
ResNet101 = partial(
ResNet101 = functools.partial(
ResNet, stage_sizes=(3, 4, 23, 3), block_cls=BottleneckResNetBlock)
ResNet152 = partial(
ResNet152 = functools.partial(
ResNet, stage_sizes=(3, 8, 36, 3), block_cls=BottleneckResNetBlock)
ResNet200 = partial(
ResNet200 = functools.partial(
ResNet, stage_sizes=(3, 24, 36, 3), block_cls=BottleneckResNetBlock)

# Used for testing only.
_ResNet1 = partial(ResNet, stage_sizes=(1,), block_cls=ResNetBlock)
_ResNet1 = functools.partial(ResNet, stage_sizes=(1,), block_cls=ResNetBlock)
@@ -86,18 +86,14 @@ def model_params_types(self):
self._param_shapes.unfreeze())
return self._param_types

def initialized(self, key, model):
input_shape = (1, 224, 224, 3)
variables = jax.jit(model.init)({'params': key},
jnp.ones(input_shape, model.dtype))
model_state, params = variables.pop('params')
return params, model_state

def init_model_fn(self, rng: spec.RandomState) -> spec.ModelInitState:
model_cls = getattr(models, 'ResNet50')
model = model_cls(num_classes=self._num_classes, dtype=jnp.float32)
self._model = model
params, model_state = self.initialized(rng, model)
input_shape = (1, 224, 224, 3)
variables = jax.jit(model.init)({'params': rng},
jnp.ones(input_shape, model.dtype))
model_state, params = variables.pop('params')
self._param_shapes = jax.tree_map(lambda x: spec.ShapeTuple(x.shape),
params)
model_state = jax_utils.replicate(model_state)
@@ -158,13 +154,14 @@ def model_fn(
augmented_and_preprocessed_input_batch['inputs'],
update_batch_norm=update_batch_norm,
mutable=False)
return logits, None
return logits, model_state

# Does NOT apply regularization, which is left to the submitter to do in
# `update_params`.
def loss_fn(self,
label_batch: spec.Tensor,
logits_batch: spec.Tensor,
mask_batch: Optional[spec.Tensor] = None,
label_smoothing: float = 0.0) -> spec.Tensor: # differentiable
"""Cross Entropy Loss"""
if label_batch.shape[-1] != self._num_classes:
@@ -173,8 +170,11 @@ def loss_fn(self,
else:
one_hot_labels = label_batch
smoothed_labels = optax.smooth_labels(one_hot_labels, label_smoothing)
return optax.softmax_cross_entropy(
logits=logits_batch, labels=smoothed_labels)
losses = -jnp.sum(smoothed_labels * jax.nn.log_softmax(logits_batch, axis=-1), axis=-1)
# mask_batch is assumed to be shape [batch].
if mask_batch is not None:
losses *= mask_batch
return losses

def _compute_metrics(self, logits, labels):
loss = jnp.sum(self.loss_fn(labels, logits))
@@ -213,12 +213,17 @@ def output_activation_fn(self,
def loss_fn(self,
label_batch: spec.Tensor,
logits_batch: spec.Tensor,
mask_batch: Optional[spec.Tensor] = None,
label_smoothing: float = 0.0) -> spec.Tensor: # differentiable
return F.cross_entropy(
losses = F.cross_entropy(
logits_batch,
label_batch,
reduction='none',
label_smoothing=label_smoothing)
# mask_batch is assumed to be shape [batch].
if mask_batch is not None:
losses *= mask_batch
return losses

def _eval_metric(self, logits, labels):
"""Return the mean accuracy and loss as a dict."""
4 changes: 2 additions & 2 deletions algorithmic_efficiency/workloads/imagenet_resnet/workload.py
@@ -19,7 +19,7 @@ def has_reached_goal(self, eval_result: float) -> bool:

@property
def target_value(self):
return 0.76
return 0.771850005

@property
def loss_type(self):
@@ -73,7 +73,7 @@ def max_allowed_runtime_sec(self):

@property
def eval_period_time_sec(self):
return 6000 # 100 mins
return 510 # 8.5 minutes.

@property
def param_shapes(self):
4 changes: 2 additions & 2 deletions algorithmic_efficiency/workloads/imagenet_vit/workload.py
@@ -52,15 +52,15 @@ class BaseImagenetVitWorkload(BaseImagenetResNetWorkload):

@property
def target_value(self):
return 0.76
return 0.77171

@property
def max_allowed_runtime_sec(self):
return 111600 # 31 hours

@property
def eval_period_time_sec(self):
return 6000 # 100 mins
return 7 * 60 # 7 mins.

def _build_dataset(self,
data_rng: spec.RandomState,