Skip to content

Commit

Permalink
Add docker image
Browse files Browse the repository at this point in the history
  • Loading branch information
pomonam committed Dec 9, 2023
1 parent c8a9e72 commit ecf8220
Show file tree
Hide file tree
Showing 4 changed files with 16 additions and 5 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,7 @@
import math

import tensorflow as tf

# from tensorflow_addons import image as contrib_image
from tensorflow_addons import image as contrib_image

# This signifies the max integer that the controller RNN could predict for the
# augmentation scheme.
Expand Down
12 changes: 12 additions & 0 deletions algorithmic_efficiency/workloads/workloads.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,18 @@
'workload_path': 'imagenet_vit/imagenet',
'workload_class_name': 'ImagenetVitWorkload',
},
'imagenet_vit_glu': {
'workload_path': 'imagenet_vit/imagenet',
'workload_class_name': 'ImagenetVitGluWorkload',
},
'imagenet_vit_post_ln': {
'workload_path': 'imagenet_vit/imagenet',
'workload_class_name': 'ImagenetViTPostLNWorkload',
},
'imagenet_vit_map': {
'workload_path': 'imagenet_vit/imagenet',
'workload_class_name': 'ImagenetViTMapLNWorkload',
},
'librispeech_conformer': {
'workload_path': 'librispeech_conformer/librispeech',
'workload_class_name': 'LibriSpeechConformerWorkload',
Expand Down
3 changes: 2 additions & 1 deletion docker/scripts/startup.sh
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,8 @@ done
VALID_DATASETS=("criteo1tb" "imagenet" "fastmri" "ogbg" "librispeech" \
"wmt" "mnist")
VALID_WORKLOADS=("criteo1tb" "imagenet_resnet" "imagenet_resnet_silu" "imagenet_resnet_gelu" \
"imagenet_resnet_large_bn_init" "imagenet_vit" "fastmri" "ogbg" \
"imagenet_resnet_large_bn_init" "imagenet_vit" "imagenet_vit_glu" \
"imagenet_vit_post_ln" "imagenet_vit_map" "fastmri" "ogbg" \
"wmt" "librispeech_deepspeech" "librispeech_conformer" "mnist" \
"criteo1tb_resnet" "criteo1tb_layernorm" "criteo_embed_init" \
"conformer_layernorm" "conformer_attention_temperature" \
Expand Down
3 changes: 1 addition & 2 deletions tests/modeldiffs/imagenet_vit/compare.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,5 @@
import os

from tests.modeldiffs.diff import out_diff

# Disable GPU access for both jax and pytorch.
os.environ['CUDA_VISIBLE_DEVICES'] = ''

Expand All @@ -13,6 +11,7 @@
ImagenetVitWorkload as JaxWorkload
from algorithmic_efficiency.workloads.imagenet_vit.imagenet_pytorch.workload import \
ImagenetVitWorkload as PytWorkload
from tests.modeldiffs.diff import out_diff


def key_transform(k):
Expand Down

0 comments on commit ecf8220

Please sign in to comment.