Skip to content

Commit

Permalink
Merge branch 'dev' into feat/pt2-compile-loss-fn
Browse files Browse the repository at this point in the history
  • Loading branch information
BoyuanFeng authored Dec 8, 2023
2 parents 33ecf8e + 64d1a85 commit a134d08
Show file tree
Hide file tree
Showing 41 changed files with 1,682 additions and 159 deletions.
85 changes: 85 additions & 0 deletions .github/workflows/regression_tests_variants.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
# Containerized regression tests for the criteo1tb workload variants.
#
# On every pull request targeting `main` this workflow:
#   1. builds one Docker image per framework (jax, pytorch) for the PR branch
#      and pushes it to the project's Artifact Registry repository;
#   2. pulls that image and runs each workload variant inside it.
#
# The branch name is attacker-controlled on `pull_request` events, so it is
# always passed to the scripts through `env:` (as $GIT_BRANCH) instead of being
# spliced into the script text with `${{ ... }}`. This prevents a crafted branch
# name from injecting shell commands on the self-hosted runner.
name: Containerized Regression Tests for Workload Variants

on:
  pull_request:
    branches:
      - 'main'

jobs:
  build_and_push_jax_docker_image:
    runs-on: self-hosted
    steps:
      - uses: actions/checkout@v2
      - name: Build and push docker images
        env:
          GIT_BRANCH: ${{ github.head_ref || github.ref_name }}
        run: |
          FRAMEWORK=jax
          IMAGE_NAME="algoperf_${FRAMEWORK}_${GIT_BRANCH}"
          cd "$HOME/algorithmic-efficiency/docker"
          docker build --no-cache -t "$IMAGE_NAME" . --build-arg framework="$FRAMEWORK" --build-arg branch="$GIT_BRANCH"
          BUILD_RETURN=$?
          if [[ ${BUILD_RETURN} != 0 ]]; then exit ${BUILD_RETURN}; fi
          docker tag "$IMAGE_NAME" "us-central1-docker.pkg.dev/training-algorithms-external/mlcommons-docker-repo/$IMAGE_NAME"
          docker push "us-central1-docker.pkg.dev/training-algorithms-external/mlcommons-docker-repo/$IMAGE_NAME"

  build_and_push_pytorch_docker_image:
    runs-on: self-hosted
    steps:
      - uses: actions/checkout@v2
      - name: Build and push docker images
        env:
          GIT_BRANCH: ${{ github.head_ref || github.ref_name }}
        run: |
          FRAMEWORK=pytorch
          IMAGE_NAME="algoperf_${FRAMEWORK}_${GIT_BRANCH}"
          cd "$HOME/algorithmic-efficiency/docker"
          docker build --no-cache -t "$IMAGE_NAME" . --build-arg framework="$FRAMEWORK" --build-arg branch="$GIT_BRANCH"
          BUILD_RETURN=$?
          if [[ ${BUILD_RETURN} != 0 ]]; then exit ${BUILD_RETURN}; fi
          docker tag "$IMAGE_NAME" "us-central1-docker.pkg.dev/training-algorithms-external/mlcommons-docker-repo/$IMAGE_NAME"
          docker push "us-central1-docker.pkg.dev/training-algorithms-external/mlcommons-docker-repo/$IMAGE_NAME"

  criteo_layernorm_jax:
    runs-on: self-hosted
    needs: build_and_push_jax_docker_image
    steps:
      - uses: actions/checkout@v2
      - name: Run containerized workload
        env:
          GIT_BRANCH: ${{ github.head_ref || github.ref_name }}
        run: |
          IMAGE="us-central1-docker.pkg.dev/training-algorithms-external/mlcommons-docker-repo/algoperf_jax_${GIT_BRANCH}"
          docker pull "$IMAGE"
          docker run -v "$HOME/data/:/data/" -v "$HOME/experiment_runs/:/experiment_runs" -v "$HOME/experiment_runs/logs:/logs" --gpus all --ipc=host "$IMAGE" -d criteo1tb -f jax -s baselines/adamw/jax/submission.py -w criteo1tb_layernorm -t baselines/adamw/tuning_search_space.json -e tests/regression_tests/adamw -m 10 -c False -o True -r false

  criteo_resnet_jax:
    runs-on: self-hosted
    needs: build_and_push_jax_docker_image
    steps:
      - uses: actions/checkout@v2
      - name: Run containerized workload
        env:
          GIT_BRANCH: ${{ github.head_ref || github.ref_name }}
        run: |
          IMAGE="us-central1-docker.pkg.dev/training-algorithms-external/mlcommons-docker-repo/algoperf_jax_${GIT_BRANCH}"
          docker pull "$IMAGE"
          docker run -v "$HOME/data/:/data/" -v "$HOME/experiment_runs/:/experiment_runs" -v "$HOME/experiment_runs/logs:/logs" --gpus all --ipc=host "$IMAGE" -d criteo1tb -f jax -s baselines/adamw/jax/submission.py -w criteo1tb_resnet -t baselines/adamw/tuning_search_space.json -e tests/regression_tests/adamw -m 10 -c False -o True -r false

  criteo_layernorm_pytorch:
    runs-on: self-hosted
    needs: build_and_push_pytorch_docker_image
    steps:
      - uses: actions/checkout@v2
      - name: Run containerized workload
        env:
          GIT_BRANCH: ${{ github.head_ref || github.ref_name }}
        run: |
          IMAGE="us-central1-docker.pkg.dev/training-algorithms-external/mlcommons-docker-repo/algoperf_pytorch_${GIT_BRANCH}"
          docker pull "$IMAGE"
          docker run -v "$HOME/data/:/data/" -v "$HOME/experiment_runs/:/experiment_runs" -v "$HOME/experiment_runs/logs:/logs" --gpus all --ipc=host "$IMAGE" -d criteo1tb -f pytorch -s baselines/adamw/pytorch/submission.py -w criteo1tb_layernorm -t baselines/adamw/tuning_search_space.json -e tests/regression_tests/adamw -m 10 -c False -o True -r false

  criteo_resnet_pytorch:
    runs-on: self-hosted
    needs: build_and_push_pytorch_docker_image
    steps:
      - uses: actions/checkout@v2
      - name: Run containerized workload
        env:
          GIT_BRANCH: ${{ github.head_ref || github.ref_name }}
        run: |
          IMAGE="us-central1-docker.pkg.dev/training-algorithms-external/mlcommons-docker-repo/algoperf_pytorch_${GIT_BRANCH}"
          docker pull "$IMAGE"
          docker run -v "$HOME/data/:/data/" -v "$HOME/experiment_runs/:/experiment_runs" -v "$HOME/experiment_runs/logs:/logs" --gpus all --ipc=host "$IMAGE" -d criteo1tb -f pytorch -s baselines/adamw/pytorch/submission.py -w criteo1tb_resnet -t baselines/adamw/tuning_search_space.json -e tests/regression_tests/adamw -m 10 -c False -o True -r false

  # This job was previously declared under a second `criteo_resnet_pytorch`
  # key. Job ids must be unique within `jobs:`, and the job runs the
  # `criteo1tb_embed_init` workload, so it gets its own id.
  criteo_embed_init_pytorch:
    runs-on: self-hosted
    needs: build_and_push_pytorch_docker_image
    steps:
      - uses: actions/checkout@v2
      - name: Run containerized workload
        env:
          GIT_BRANCH: ${{ github.head_ref || github.ref_name }}
        run: |
          IMAGE="us-central1-docker.pkg.dev/training-algorithms-external/mlcommons-docker-repo/algoperf_pytorch_${GIT_BRANCH}"
          docker pull "$IMAGE"
          docker run -v "$HOME/data/:/data/" -v "$HOME/experiment_runs/:/experiment_runs" -v "$HOME/experiment_runs/logs:/logs" --gpus all --ipc=host "$IMAGE" -d criteo1tb -f pytorch -s baselines/adamw/pytorch/submission.py -w criteo1tb_embed_init -t baselines/adamw/tuning_search_space.json -e tests/regression_tests/adamw -m 10 -c False -o True -r false
111 changes: 109 additions & 2 deletions algorithmic_efficiency/workloads/criteo1tb/criteo1tb_jax/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,101 @@
import jax.numpy as jnp


class DLRMResNet(nn.Module):
  """Define a DLRMResNet model.

  The input `x` is split column-wise into `num_dense_features` dense columns
  followed by categorical columns. The dense part goes through a bottom MLP
  whose layers after the first are added back residually; the categorical part
  is looked up in a single embedding table. Both are concatenated and fed to a
  top MLP (again residual after the first layer) that ends in a single logit.

  Parameters:
    vocab_size: the size of a single unified embedding table.
    mlp_bottom_dims: dimensions of dense layers of the bottom mlp.
    mlp_top_dims: dimensions of dense layers of the top mlp. The last entry is
      not read: the output layer width is hard-coded to 1.
    num_dense_features: number of dense features as the bottom mlp input.
    embed_dim: embedding dimension.
    dropout_rate: dropout rate applied to the last hidden top-mlp layer.
      `None` or `0.0` disables dropout.
    use_layer_norm: accepted for interface parity with DlrmSmall; unused.
    embedding_init_multiplier: accepted for interface parity with DlrmSmall;
      unused.
  """

  vocab_size: int = 32 * 128 * 1024  # 4_194_304
  num_dense_features: int = 13
  mlp_bottom_dims: Sequence[int] = (256, 256, 256)
  mlp_top_dims: Sequence[int] = (256, 256, 256, 256, 1)
  embed_dim: int = 128
  dropout_rate: float = 0.0
  use_layer_norm: bool = False  # Unused.
  embedding_init_multiplier: float = None  # Unused

  @nn.compact
  def __call__(self, x, train):
    """Computes logits for a batch.

    Args:
      x: 2-D array whose first `num_dense_features` columns are dense features
        and whose remaining columns are categorical ids (cast to int32 here).
        The reshape below assumes 26 categorical columns.
      train: whether dropout is active (`deterministic=not train`).

    Returns:
      The output of the final width-1 Dense layer.
    """
    bot_mlp_input, cat_features = jnp.split(x, [self.num_dense_features], 1)
    cat_features = jnp.asarray(cat_features, dtype=jnp.int32)

    # bottom mlp
    mlp_bottom_dims = self.mlp_bottom_dims

    bot_mlp_input = nn.Dense(
        mlp_bottom_dims[0],
        kernel_init=jnn.initializers.glorot_uniform(),
        bias_init=jnn.initializers.normal(stddev=1.0 / mlp_bottom_dims[0]**0.5),
    )(
        bot_mlp_input)
    bot_mlp_input = nn.relu(bot_mlp_input)

    # Remaining bottom layers are residual: each layer's activation is added
    # onto the running representation.
    for dense_dim in mlp_bottom_dims[1:]:
      x = nn.Dense(
          dense_dim,
          kernel_init=jnn.initializers.glorot_uniform(),
          bias_init=jnn.initializers.normal(stddev=1.0 / dense_dim**0.5),
      )(
          bot_mlp_input)
      bot_mlp_input += nn.relu(x)

    base_init_fn = jnn.initializers.uniform(scale=1.0)
    # Embedding table init and lookup for a single unified table.
    # Ids are wrapped into the table with a modulo.
    idx_lookup = jnp.reshape(cat_features, [-1]) % self.vocab_size

    def scaled_init(key, shape, dtype=jnp.float_):
      return base_init_fn(key, shape, dtype) / jnp.sqrt(self.vocab_size)

    embedding_table = self.param('embedding_table',
                                 scaled_init, [self.vocab_size, self.embed_dim])

    embed_features = embedding_table[idx_lookup]
    batch_size = bot_mlp_input.shape[0]
    # 26 is the hard-coded number of categorical features per example.
    embed_features = jnp.reshape(embed_features,
                                 (batch_size, 26 * self.embed_dim))
    top_mlp_input = jnp.concatenate([bot_mlp_input, embed_features], axis=1)
    mlp_input_dim = top_mlp_input.shape[1]
    mlp_top_dims = self.mlp_top_dims
    num_layers_top = len(mlp_top_dims)
    top_mlp_input = nn.Dense(
        mlp_top_dims[0],
        kernel_init=jnn.initializers.normal(
            stddev=jnp.sqrt(2.0 / (mlp_input_dim + mlp_top_dims[0]))),
        bias_init=jnn.initializers.normal(
            stddev=jnp.sqrt(1.0 / mlp_top_dims[0])))(
                top_mlp_input)
    top_mlp_input = nn.relu(top_mlp_input)
    # Hidden top layers (everything except the first and the output layer),
    # each added residually.
    for layer_idx, fan_out in list(enumerate(mlp_top_dims))[1:-1]:
      fan_in = mlp_top_dims[layer_idx - 1]
      x = nn.Dense(
          fan_out,
          kernel_init=jnn.initializers.normal(
              stddev=jnp.sqrt(2.0 / (fan_in + fan_out))),
          bias_init=jnn.initializers.normal(
              stddev=jnp.sqrt(1.0 / mlp_top_dims[layer_idx])))(
                  top_mlp_input)
      x = nn.relu(x)
      # dropout_rate may be None (callers pass an Optional[float]); comparing
      # None with a float raises TypeError, so check for None first, exactly
      # as DlrmSmall does.
      if (self.dropout_rate is not None and self.dropout_rate > 0.0 and
          layer_idx == num_layers_top - 2):
        x = nn.Dropout(rate=self.dropout_rate, deterministic=not train)(x)
      top_mlp_input += x
    # In the DLRM model the last layer width is always 1. We can hardcode that
    # below.
    logits = nn.Dense(
        1,
        kernel_init=jnn.initializers.normal(
            stddev=jnp.sqrt(2.0 / (mlp_top_dims[-2] + 1))),
        bias_init=jnn.initializers.normal(stddev=jnp.sqrt(1.0)))(
            top_mlp_input)
    return logits


def dot_interact(concat_features):
"""Performs feature interaction operation between dense or sparse features.
Input tensors represent dense or sparse features.
Expand Down Expand Up @@ -52,6 +147,8 @@ class DlrmSmall(nn.Module):
mlp_top_dims: Sequence[int] = (1024, 1024, 512, 256, 1)
embed_dim: int = 128
dropout_rate: float = 0.0
use_layer_norm: bool = False
embedding_init_multiplier: float = None

@nn.compact
def __call__(self, x, train):
Expand All @@ -67,6 +164,8 @@ def __call__(self, x, train):
)(
bot_mlp_input)
bot_mlp_input = nn.relu(bot_mlp_input)
if self.use_layer_norm:
bot_mlp_input = nn.LayerNorm()(bot_mlp_input)
bot_mlp_output = bot_mlp_input
batch_size = bot_mlp_output.shape[0]
feature_stack = jnp.reshape(bot_mlp_output,
Expand All @@ -75,9 +174,13 @@ def __call__(self, x, train):
# Embedding table look-up.
idx_lookup = jnp.reshape(cat_features, [-1]) % self.vocab_size

if self.embedding_init_multiplier is None:
scale = 1 / jnp.sqrt(self.vocab_size)
else:
scale = self.embedding_init_multiplier

def scaled_init(key, shape, dtype=jnp.float_):
return (jnn.initializers.uniform(scale=1.0)(key, shape, dtype) /
jnp.sqrt(self.vocab_size))
return jnn.initializers.uniform(scale=1.0)(key, shape, dtype) * scale

embedding_table = self.param('embedding_table',
scaled_init, [self.vocab_size, self.embed_dim])
Expand All @@ -86,6 +189,8 @@ def scaled_init(key, shape, dtype=jnp.float_):
embed_features = embedding_table[idx_lookup]
embed_features = jnp.reshape(embed_features,
[batch_size, -1, self.embed_dim])
if self.use_layer_norm:
embed_features = nn.LayerNorm()(embed_features)
feature_stack = jnp.concatenate([feature_stack, embed_features], axis=1)
dot_interact_output = dot_interact(concat_features=feature_stack)
top_mlp_input = jnp.concatenate([bot_mlp_output, dot_interact_output],
Expand All @@ -103,6 +208,8 @@ def scaled_init(key, shape, dtype=jnp.float_):
top_mlp_input)
if layer_idx < (num_layers_top - 1):
top_mlp_input = nn.relu(top_mlp_input)
if self.use_layer_norm:
top_mlp_input = nn.LayerNorm()(top_mlp_input)
if (self.dropout_rate is not None and self.dropout_rate > 0.0 and
layer_idx == num_layers_top - 2):
top_mlp_input = nn.Dropout(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -73,23 +73,31 @@ def init_model_fn(
self,
rng: spec.RandomState,
dropout_rate: Optional[float] = None,
aux_dropout_rate: Optional[float] = None) -> spec.ModelInitState:
aux_dropout_rate: Optional[float] = None,
tabulate: Optional[bool] = False,
) -> spec.ModelInitState:
"""Only dropout is used."""
del aux_dropout_rate
self._model = models.DlrmSmall(
if self.use_resnet:
model_class = models.DLRMResNet
else:
model_class = models.DlrmSmall
self._model = model_class(
vocab_size=self.vocab_size,
num_dense_features=self.num_dense_features,
mlp_bottom_dims=self.mlp_bottom_dims,
mlp_top_dims=self.mlp_top_dims,
embed_dim=self.embed_dim,
dropout_rate=dropout_rate)
dropout_rate=dropout_rate,
use_layer_norm=self.use_layer_norm,
embedding_init_multiplier=self.embedding_init_multiplier)

params_rng, dropout_rng = jax.random.split(rng)
init_fake_batch_size = 2
num_categorical_features = 26
input_size = self.num_dense_features + num_categorical_features
num_dense_features = 13
input_size = num_dense_features + num_categorical_features
input_shape = (init_fake_batch_size, input_size)

init_fn = functools.partial(self._model.init, train=False)
initial_variables = jax.jit(init_fn)(
{'params': params_rng, 'dropout': dropout_rng},
Expand Down Expand Up @@ -154,3 +162,53 @@ def _eval_batch(self,

class Criteo1TbDlrmSmallTestWorkload(Criteo1TbDlrmSmallWorkload):
  # Only the embedding vocabulary size is overridden here (32 * 128 * 16 =
  # 65_536 rows); everything else is inherited from the parent workload.
  vocab_size: int = 32 * 128 * 16


class Criteo1TbDlrmSmallLayerNormWorkload(Criteo1TbDlrmSmallWorkload):
  """Criteo1TB DLRM-small variant that switches LayerNorm on in the model."""

  @property
  def validation_target_value(self) -> float:
    """Validation target for this variant."""
    return 0.123744

  @property
  def test_target_value(self) -> float:
    """Test target for this variant."""
    return 0.126152

  @property
  def use_layer_norm(self) -> bool:
    """Always True: this variant enables LayerNorm in the model."""
    return True


class Criteo1TbDlrmSmallResNetWorkload(Criteo1TbDlrmSmallWorkload):
  """Criteo1TB DLRM-small variant that selects the residual (ResNet) model.

  Overrides the MLP layer widths so every hidden layer is 256 wide.
  """

  # Variable-length int tuples: the previous fixed-arity annotations
  # (Tuple[int, int] / Tuple[int, int, int]) did not match the 3- and
  # 5-element values assigned here.
  mlp_bottom_dims: Tuple[int, ...] = (256, 256, 256)
  mlp_top_dims: Tuple[int, ...] = (256, 256, 256, 256, 1)

  @property
  def use_resnet(self) -> bool:
    """Whether or not to use residual connections in the model."""
    return True

  @property
  def validation_target_value(self) -> float:
    """Validation target for this variant."""
    return 0.124027

  @property
  def test_target_value(self) -> float:
    """Test target for this variant."""
    return 0.126468


class Criteo1TbDlrmSmallEmbedInitWorkload(Criteo1TbDlrmSmallWorkload):
  """Criteo1TB DLRM-small variant with an explicit embedding init multiplier."""

  @property
  def embedding_init_multiplier(self) -> float:
    """Multiplier handed to the model for embedding-table initialization."""
    return 1.0

  @property
  def validation_target_value(self) -> float:
    """Validation target for this variant."""
    return 0.124286

  @property
  def test_target_value(self) -> float:
    """Test target for this variant."""
    # TODO: marker carried over from the original code; the pending work is
    # not described there.
    return 0.126725
Loading

0 comments on commit a134d08

Please sign in to comment.