
Commit

Merge pull request #490 from mlcommons/dev
dev -> main
priyakasimbeg authored Aug 30, 2023
2 parents aa0d692 + 4c38ffb commit c3273d9
Showing 19 changed files with 380 additions and 237 deletions.
111 changes: 79 additions & 32 deletions .github/workflows/regression_tests.yml

Large diffs are not rendered by default.

16 changes: 8 additions & 8 deletions README.md
@@ -123,7 +123,7 @@ To use the Docker container as an interactive virtual environment, you can run a
--gpus all \
--ipc=host \
<docker_image_name>
-  -keep_container_alive true
+  --keep_container_alive true
```
2. Open a bash terminal
```bash
@@ -148,8 +148,8 @@ python3 submission_runner.py \
--workload=mnist \
--experiment_dir=$HOME/experiments \
--experiment_name=my_first_experiment \
-  --submission_path=reference_algorithms/development_algorithms/mnist/mnist_jax/submission.py \
-  --tuning_search_space=reference_algorithms/development_algorithms/mnist/tuning_search_space.json
+  --submission_path=baselines/adamw/jax/submission.py \
+  --tuning_search_space=baselines/adamw/tuning_search_space.json
```

**Pytorch**
@@ -160,8 +160,8 @@ python3 submission_runner.py \
--workload=mnist \
--experiment_dir=$HOME/experiments \
--experiment_name=my_first_experiment \
-  --submission_path=reference_algorithms/development_algorithms/mnist/mnist_pytorch/submission.py \
-  --tuning_search_space=reference_algorithms/development_algorithms/mnist/tuning_search_space.json
+  --submission_path=baselines/adamw/pytorch/submission.py \
+  --tuning_search_space=baselines/adamw/tuning_search_space.json
```
<details>
<summary>
@@ -186,10 +186,10 @@ torchrun --redirects 1:0,2:0,3:0,4:0,5:0,6:0,7:0 --standalone --nnodes=1 --nproc
submission_runner.py \
--framework=pytorch \
--workload=mnist \
-  --experiment_dir=/home/znado \
+  --experiment_dir=$HOME/experiments \
--experiment_name=baseline \
-  --submission_path=reference_algorithms/development_algorithms/mnist/mnist_pytorch/submission.py \
-  --tuning_search_space=reference_algorithms/development_algorithms/mnist/tuning_search_space.json \
+  --submission_path=baselines/adamw/pytorch/submission.py \
+  --tuning_search_space=baselines/adamw/tuning_search_space.json
```
</details>

@@ -18,7 +18,7 @@ def dot_interact(concat_features):
"""
batch_size = concat_features.shape[0]

-  # Interact features, select upper or lower-triangular portion, and re-shape.
+  # Interact features, select upper or lower-triangular portion, and reshape.
xactions = jnp.matmul(concat_features,
jnp.transpose(concat_features, [0, 2, 1]))
feature_dim = xactions.shape[-1]
@@ -46,7 +46,7 @@ class DlrmSmall(nn.Module):
embed_dim: embedding dimension.
"""

-  vocab_size: int = 32 * 128 * 1024  # 4_194_304
+  vocab_size: int = 32 * 128 * 1024  # 4_194_304.
num_dense_features: int = 13
mlp_bottom_dims: Sequence[int] = (512, 256, 128)
mlp_top_dims: Sequence[int] = (1024, 1024, 512, 256, 1)
@@ -1,4 +1,5 @@
"""Criteo1TB workload implemented in Jax."""

import functools
from typing import Dict, Optional, Tuple

102 changes: 54 additions & 48 deletions algorithmic_efficiency/workloads/criteo1tb/criteo1tb_pytorch/models.py
@@ -6,33 +6,23 @@
from torch import nn


-def dot_interact(concat_features):
-  """Performs feature interaction operation between dense or sparse features.
-  Input tensors represent dense or sparse features.
-  Pre-condition: The tensors have been stacked along dimension 1.
-  Args:
-    concat_features: Array of features with shape [B, n_features, feature_dim].
-  Returns:
-    activations: Array representing interacted features.
-  """
-  batch_size = concat_features.shape[0]
-
-  # Interact features, select upper or lower-triangular portion, and re-shape.
-  xactions = torch.bmm(concat_features,
-                       torch.permute(concat_features, (0, 2, 1)))
-  feature_dim = xactions.shape[-1]
-
-  indices = torch.triu_indices(feature_dim, feature_dim)
-  num_elems = indices.shape[1]
-  indices = torch.tile(indices, [1, batch_size])
-  indices0 = torch.reshape(
-      torch.tile(
-          torch.reshape(torch.arange(batch_size), [-1, 1]), [1, num_elems]),
-      [1, -1])
-  indices = tuple(torch.cat((indices0, indices), 0))
-  activations = xactions[indices]
-  activations = torch.reshape(activations, [batch_size, -1])
-  return activations
+class DotInteract(nn.Module):
+  """Performs feature interaction operation between dense or sparse features."""
+
+  def __init__(self, num_sparse_features):
+    super().__init__()
+    self.triu_indices = torch.triu_indices(num_sparse_features + 1,
+                                           num_sparse_features + 1)
+
+  def forward(self, dense_features, sparse_features):
+    combined_values = torch.cat((dense_features.unsqueeze(1), sparse_features),
+                                dim=1)
+    interactions = torch.bmm(combined_values,
+                             torch.transpose(combined_values, 1, 2))
+    interactions_flat = interactions[:,
+                                     self.triu_indices[0],
+                                     self.triu_indices[1]]
+    return torch.cat((dense_features, interactions_flat), dim=1)

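To see what the new module computes: the dense embedding and the sparse embeddings are stacked into `num_sparse_features + 1` vectors, and the upper triangle (diagonal included) of their pairwise dot-product matrix contributes `(F + 1)(F + 2) / 2` interaction terms, concatenated after the dense features. A minimal shape check, assuming the `DotInteract` class above and purely illustrative sizes:

```python
import torch

# Illustrative sizes only: 3 sparse features, embed_dim 4, batch of 2.
dot = DotInteract(num_sparse_features=3)
dense = torch.randn(2, 4)      # [batch, embed_dim]
sparse = torch.randn(2, 3, 4)  # [batch, num_sparse_features, embed_dim]
out = dot(dense_features=dense, sparse_features=sparse)
# 4 stacked vectors -> 4 * 5 / 2 = 10 interaction terms, plus the 4
# dense features prepended: output shape [2, 14].
print(out.shape)  # torch.Size([2, 14])
```

With the DLRM-small configuration used here (26 sparse features, `embed_dim = 128`) this gives 128 + 27 * 28 / 2 = 506, which is the formula behind the `input_dims = 506` constant in the top-MLP setup below.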

class DlrmSmall(nn.Module):
@@ -62,13 +52,21 @@ def __init__(self,
self.mlp_top_dims = mlp_top_dims
self.embed_dim = embed_dim

-    self.embedding_table = nn.Embedding(self.vocab_size, self.embed_dim)
-    self.embedding_table.weight.data.uniform_(0, 1)
     # Scale the initialization to fan_in for each slice.
+    # Ideally, we should use the pooled embedding implementation from
+    # `TorchRec`. However, in order to have an implementation identical
+    # to that of Jax, we define a single embedding matrix.
+    num_chunks = 4
+    assert vocab_size % num_chunks == 0
+    self.embedding_table_chunks = []
     scale = 1.0 / math.sqrt(self.vocab_size)
-    self.embedding_table.weight.data = scale * self.embedding_table.weight.data
+    for i in range(num_chunks):
+      chunk = nn.Parameter(
+          torch.Tensor(self.vocab_size // num_chunks, self.embed_dim))
+      chunk.data.uniform_(0, 1)
+      chunk.data = scale * chunk.data
+      self.register_parameter(f'embedding_chunk_{i}', chunk)
+      self.embedding_table_chunks.append(chunk)

# bottom mlp
bottom_mlp_layers = []
input_dim = self.num_dense_features
for dense_dim in self.mlp_bottom_dims:
@@ -84,8 +82,9 @@ def __init__(self,
0.,
math.sqrt(1. / module.out_features))

-    # top mlp
-    # TODO (JB): Write down the formula here instead of the constant.
+    self.dot_interact = DotInteract(num_sparse_features=num_sparse_features,)
+
+    # TODO: Write down the formula here instead of the constant.
    input_dims = 506
top_mlp_layers = []
num_layers_top = len(self.mlp_top_dims)
@@ -110,19 +109,26 @@ def __init__(self,
math.sqrt(1. / module.out_features))

  def forward(self, x):
-    bot_mlp_input, cat_features = torch.split(
+    batch_size = x.shape[0]
+
+    dense_features, sparse_features = torch.split(
        x, [self.num_dense_features, self.num_sparse_features], 1)
-    cat_features = cat_features.to(dtype=torch.int32)
-    bot_mlp_output = self.bot_mlp(bot_mlp_input)
-    batch_size = bot_mlp_output.shape[0]
-    feature_stack = torch.reshape(bot_mlp_output,
-                                  [batch_size, -1, self.embed_dim])
-    idx_lookup = torch.reshape(cat_features, [-1]) % self.vocab_size
-    embed_features = self.embedding_table(idx_lookup)
-    embed_features = torch.reshape(embed_features,
-                                   [batch_size, -1, self.embed_dim])
-    feature_stack = torch.cat([feature_stack, embed_features], axis=1)
-    dot_interact_output = dot_interact(concat_features=feature_stack)
-    top_mlp_input = torch.cat([bot_mlp_output, dot_interact_output], axis=-1)
-    logits = self.top_mlp(top_mlp_input)
+
+    # Bottom MLP.
+    embedded_dense = self.bot_mlp(dense_features)
+
+    # Sparse feature processing.
+    sparse_features = sparse_features.to(dtype=torch.int32)
+    idx_lookup = torch.reshape(sparse_features, [-1]) % self.vocab_size
+    embedding_table = torch.cat(self.embedding_table_chunks, dim=0)
+    embedded_sparse = embedding_table[idx_lookup]
+    embedded_sparse = torch.reshape(embedded_sparse,
+                                    [batch_size, -1, self.embed_dim])
+
+    # Dot product interactions.
+    concatenated_dense = self.dot_interact(
+        dense_features=embedded_dense, sparse_features=embedded_sparse)
+
+    # Final MLP.
+    logits = self.top_mlp(concatenated_dense)
    return logits
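The chunked table is functionally equivalent to a single `[vocab_size, embed_dim]` embedding; `forward` simply concatenates the chunks before indexing. A small equivalence sketch with toy sizes (the `nn.Embedding` here is only a reference point, not part of the model):

```python
import torch
from torch import nn

vocab_size, embed_dim, num_chunks = 8, 2, 4

# Chunked parameters, built as in __init__ above (toy sizes).
chunks = [
    nn.Parameter(torch.randn(vocab_size // num_chunks, embed_dim))
    for _ in range(num_chunks)
]

# Lookup path from forward(): concatenate chunks, hash indices, index rows.
full_table = torch.cat(chunks, dim=0)       # [vocab_size, embed_dim]
idx = torch.tensor([0, 3, 7]) % vocab_size
embedded = full_table[idx]                  # [3, embed_dim]

# Identical to a single embedding holding the concatenated weights.
ref = nn.Embedding(vocab_size, embed_dim)
ref.weight.data = full_table.detach()
assert torch.allclose(embedded, ref(idx))
```

Registering each slice with `register_parameter` keeps all four chunks visible to the optimizer while avoiding one very large parameter tensor.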
@@ -1,8 +1,8 @@
"""Criteo1TB workload implemented in PyTorch."""

import contextlib
-from typing import Dict, Optional, Tuple
+from typing import Dict, Iterator, Optional, Tuple

import jax
import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
@@ -22,7 +22,7 @@ class Criteo1TbDlrmSmallWorkload(BaseCriteo1TbDlrmSmallWorkload):

@property
def eval_batch_size(self) -> int:
-    return 262_144
+    return 32_768

def _per_example_sigmoid_binary_cross_entropy(
self, logits: spec.Tensor, targets: spec.Tensor) -> spec.Tensor:
@@ -66,11 +66,6 @@ def loss_fn(
'per_example': per_example_losses,
}

-  def _eval_metric(self, logits: spec.Tensor,
-                   targets: spec.Tensor) -> Dict[str, int]:
-    summed_loss = self.loss_fn(logits, targets)['summed']
-    return {'loss': summed_loss}

def init_model_fn(
self,
rng: spec.RandomState,
@@ -79,6 +74,8 @@ def init_model_fn(
"""Only dropout is used."""
del aux_dropout_rate
torch.random.manual_seed(rng[0])
+    # Disable cudnn benchmark to avoid OOM errors.
+    torch.backends.cudnn.benchmark = False
model = DlrmSmall(
vocab_size=self.vocab_size,
num_dense_features=self.num_dense_features,
@@ -130,25 +127,28 @@ def model_fn(

return logits_batch, None

-  def _build_input_queue(self,
-                         data_rng: jax.random.PRNGKey,
-                         split: str,
-                         data_dir: str,
-                         global_batch_size: int,
-                         num_batches: Optional[int] = None,
-                         repeat_final_dataset: bool = False):
+  def _build_input_queue(
+      self,
+      data_rng: spec.RandomState,
+      split: str,
+      data_dir: str,
+      global_batch_size: int,
+      cache: Optional[bool] = None,
+      repeat_final_dataset: Optional[bool] = None,
+      num_batches: Optional[int] = None) -> Iterator[Dict[str, spec.Tensor]]:
not_train = split != 'train'
per_device_batch_size = int(global_batch_size / N_GPUS)

# Only create and iterate over tf input pipeline in one Python process to
# avoid creating too many threads.
if RANK == 0:
-      np_iter = super()._build_input_queue(data_rng,
-                                           split,
-                                           data_dir,
-                                           global_batch_size,
-                                           num_batches,
-                                           repeat_final_dataset)
+      np_iter = super()._build_input_queue(
+          data_rng=data_rng,
+          split=split,
+          data_dir=data_dir,
+          global_batch_size=global_batch_size,
+          num_batches=num_batches,
+          repeat_final_dataset=repeat_final_dataset)
weights = None
while True:
if RANK == 0:
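Only rank 0 owns the `tf.data` pipeline; the remainder of this loop (not shown in the rendered diff) has to hand each batch to the other ranks. A minimal sketch of that broadcast pattern, assuming an initialized `torch.distributed` process group and a batch shape agreed upon by all ranks; the function and field names are illustrative, not the repo's actual helpers:

```python
import torch
import torch.distributed as dist

def distribute_batches(np_iter, batch_shape, rank, device='cuda'):
  """Sketch: rank 0 reads from the tf.data-backed iterator and broadcasts
  each batch, so only one process ever spawns input-pipeline threads."""
  while True:
    if rank == 0:
      batch = torch.as_tensor(next(np_iter)['inputs'], device=device)
    else:
      # Non-zero ranks allocate an empty buffer of the agreed shape.
      batch = torch.empty(batch_shape, device=device)
    dist.broadcast(batch, src=0)  # in-place collective from rank 0
    yield {'inputs': batch}
```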
@@ -5,6 +5,7 @@
validation). See here for the NVIDIA example:
https://github.com/NVIDIA/DeepLearningExamples/blob/4e764dcd78732ebfe105fc05ea3dc359a54f6d5e/PyTorch/Recommendation/DLRM/preproc/run_spark_cpu.sh#L119.
"""

import functools
import os
from typing import Optional
47 changes: 27 additions & 20 deletions algorithmic_efficiency/workloads/criteo1tb/workload.py
@@ -1,21 +1,24 @@
"""Criteo1TB DLRM workload base class."""

import math
import os
-from typing import Dict, Optional, Tuple
+from typing import Dict, Iterator, Optional, Tuple

-import jax
+from absl import flags
import torch.distributed as dist

from algorithmic_efficiency import spec
from algorithmic_efficiency.workloads.criteo1tb import input_pipeline

+FLAGS = flags.FLAGS

USE_PYTORCH_DDP = 'LOCAL_RANK' in os.environ


class BaseCriteo1TbDlrmSmallWorkload(spec.Workload):
"""Criteo1tb workload."""

-  vocab_size: int = 32 * 128 * 1024  # 4_194_304
+  vocab_size: int = 32 * 128 * 1024  # 4_194_304.
num_dense_features: int = 13
  mlp_bottom_dims: Tuple[int, ...] = (512, 256, 128)
  mlp_top_dims: Tuple[int, ...] = (1024, 1024, 512, 256, 1)
@@ -26,14 +29,15 @@ def target_metric_name(self) -> str:
"""The name of the target metric (useful for scoring/processing code)."""
return 'loss'

-  def has_reached_validation_target(self, eval_result: float) -> bool:
+  def has_reached_validation_target(self, eval_result: Dict[str,
+                                                             float]) -> bool:
return eval_result['validation/loss'] < self.validation_target_value

@property
def validation_target_value(self) -> float:
return 0.123649

-  def has_reached_test_target(self, eval_result: float) -> bool:
+  def has_reached_test_target(self, eval_result: Dict[str, float]) -> bool:
return eval_result['test/loss'] < self.test_target_value

@property
@@ -75,19 +79,22 @@ def train_stddev(self):

@property
def max_allowed_runtime_sec(self) -> int:
-    return 7703  # ~2 hours
+    return 7703  # ~2 hours.

@property
def eval_period_time_sec(self) -> int:
-    return 2 * 60
+    return 2 * 600  # 20 mins.

-  def _build_input_queue(self,
-                         data_rng: jax.random.PRNGKey,
-                         split: str,
-                         data_dir: str,
-                         global_batch_size: int,
-                         num_batches: Optional[int] = None,
-                         repeat_final_dataset: bool = False):
+  def _build_input_queue(
+      self,
+      data_rng: spec.RandomState,
+      split: str,
+      data_dir: str,
+      global_batch_size: int,
+      cache: Optional[bool] = None,
+      repeat_final_dataset: Optional[bool] = None,
+      num_batches: Optional[int] = None) -> Iterator[Dict[str, spec.Tensor]]:
+    del cache
ds = input_pipeline.get_criteo1tb_dataset(
split=split,
shuffle_rng=data_rng,
@@ -121,11 +128,11 @@ def _eval_model_on_split(self,
if split not in self._eval_iters:
# These iterators will repeat indefinitely.
self._eval_iters[split] = self._build_input_queue(
-          rng,
-          split,
-          data_dir,
-          global_batch_size,
-          num_batches,
+          data_rng=rng,
+          split=split,
+          data_dir=data_dir,
+          global_batch_size=global_batch_size,
+          num_batches=num_batches,
repeat_final_dataset=True)
loss = 0.0
for _ in range(num_batches):
@@ -281,7 +281,7 @@ def __init__(self, config: DeepspeechConfig):

def forward(self, inputs, input_paddings):
inputs = self.bn(inputs, input_paddings)
-    lengths = torch.sum(1 - input_paddings, dim=1).detach().cpu()
+    lengths = torch.sum(1 - input_paddings, dim=1).detach().cpu().numpy()
packed_inputs = torch.nn.utils.rnn.pack_padded_sequence(
inputs, lengths, batch_first=True, enforce_sorted=False)
packed_outputs, _ = self.lstm(packed_inputs)
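For context on the `.numpy()` change: `pack_padded_sequence` needs the per-sequence lengths on the host (a list, NumPy array, or CPU tensor), so lengths derived from GPU-resident padding masks must be moved off-device first. A self-contained illustration with toy sizes (not the workload's real model):

```python
import torch
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence

# Two sequences of feature dim 3, padded to length 4; 1.0 marks padding.
inputs = torch.randn(2, 4, 3)
input_paddings = torch.tensor([[0., 0., 0., 1.],   # real length 3
                               [0., 0., 1., 1.]])  # real length 2
# Lengths must live on the host for pack_padded_sequence.
lengths = torch.sum(1 - input_paddings, dim=1).detach().cpu().numpy()

packed = pack_padded_sequence(
    inputs, lengths, batch_first=True, enforce_sorted=False)
lstm = torch.nn.LSTM(input_size=3, hidden_size=8, batch_first=True)
packed_out, _ = lstm(packed)
outputs, _ = pad_packed_sequence(packed_out, batch_first=True)
print(outputs.shape)  # torch.Size([2, 3, 8]), padded to longest real length
```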