Skip to content

Commit

Permalink
Merge pull request #748 from mlcommons/variant_bsz_fix
Browse files Browse the repository at this point in the history
override imagenet_resnet silu and gelu bsz in configs
  • Loading branch information
priyakasimbeg authored Apr 2, 2024
2 parents e9bdbd8 + e203a42 commit d329e42
Show file tree
Hide file tree
Showing 17 changed files with 66 additions and 1 deletion.
Original file line number Diff line number Diff line change
Expand Up @@ -307,6 +307,10 @@ def get_batch_size(workload_name):
return 32
elif workload_name == 'imagenet_resnet':
return 1024
elif workload_name == 'imagenet_resnet_silu':
return 512
elif workload_name == 'imagenet_resnet_gelu':
return 512
elif workload_name == 'imagenet_vit':
return 1024
elif workload_name == 'librispeech_conformer':
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -307,6 +307,10 @@ def get_batch_size(workload_name):
return 32
elif workload_name == 'imagenet_resnet':
return 1024
elif workload_name == 'imagenet_resnet_silu':
return 512
elif workload_name == 'imagenet_resnet_gelu':
return 512
elif workload_name == 'imagenet_vit':
return 1024
elif workload_name == 'librispeech_conformer':
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -309,6 +309,10 @@ def get_batch_size(workload_name):
return 32
elif workload_name == 'imagenet_resnet':
return 1024
elif workload_name == 'imagenet_resnet_silu':
return 512
elif workload_name == 'imagenet_resnet_gelu':
return 512
elif workload_name == 'imagenet_vit':
return 1024
elif workload_name == 'librispeech_conformer':
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -309,6 +309,10 @@ def get_batch_size(workload_name):
return 32
elif workload_name == 'imagenet_resnet':
return 1024
elif workload_name == 'imagenet_resnet_silu':
return 512
elif workload_name == 'imagenet_resnet_gelu':
return 512
elif workload_name == 'imagenet_vit':
return 1024
elif workload_name == 'librispeech_conformer':
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -322,6 +322,10 @@ def get_batch_size(workload_name):
return 32
elif workload_name == 'imagenet_resnet':
return 1024
elif workload_name == 'imagenet_resnet_silu':
return 512
elif workload_name == 'imagenet_resnet_gelu':
return 512
elif workload_name == 'imagenet_vit':
return 1024
elif workload_name == 'librispeech_conformer':
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -322,6 +322,10 @@ def get_batch_size(workload_name):
return 32
elif workload_name == 'imagenet_resnet':
return 1024
elif workload_name == 'imagenet_resnet_silu':
return 512
elif workload_name == 'imagenet_resnet_gelu':
return 512
elif workload_name == 'imagenet_vit':
return 1024
elif workload_name == 'librispeech_conformer':
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -324,6 +324,10 @@ def get_batch_size(workload_name):
return 32
elif workload_name == 'imagenet_resnet':
return 1024
elif workload_name == 'imagenet_resnet_silu':
return 512
elif workload_name == 'imagenet_resnet_gelu':
return 512
elif workload_name == 'imagenet_vit':
return 1024
elif workload_name == 'librispeech_conformer':
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -324,6 +324,10 @@ def get_batch_size(workload_name):
return 32
elif workload_name == 'imagenet_resnet':
return 1024
elif workload_name == 'imagenet_resnet_silu':
return 512
elif workload_name == 'imagenet_resnet_gelu':
return 512
elif workload_name == 'imagenet_vit':
return 1024
elif workload_name == 'librispeech_conformer':
Expand Down
4 changes: 4 additions & 0 deletions reference_algorithms/paper_baselines/adamw/jax/submission.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,6 +165,10 @@ def get_batch_size(workload_name):
return 32
elif workload_name == 'imagenet_resnet':
return 1024
elif workload_name == 'imagenet_resnet_silu':
return 512
elif workload_name == 'imagenet_resnet_gelu':
return 512
elif workload_name == 'imagenet_vit':
return 1024
elif workload_name == 'librispeech_conformer':
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,10 @@ def get_batch_size(workload_name):
return 32
elif workload_name == 'imagenet_resnet':
return 1024
elif workload_name == 'imagenet_resnet_silu':
return 512
elif workload_name == 'imagenet_resnet_gelu':
return 512
elif workload_name == 'imagenet_vit':
return 1024
elif workload_name == 'librispeech_conformer':
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -199,6 +199,10 @@ def get_batch_size(workload_name):
return 32
elif workload_name == 'imagenet_resnet':
return 1024
elif workload_name == 'imagenet_resnet_silu':
return 512
elif workload_name == 'imagenet_resnet_gelu':
return 512
elif workload_name == 'imagenet_vit':
return 1024
elif workload_name == 'librispeech_conformer':
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -152,6 +152,10 @@ def get_batch_size(workload_name):
return 32
elif workload_name == 'imagenet_resnet':
return 1024
elif workload_name == 'imagenet_resnet_silu':
return 512
elif workload_name == 'imagenet_resnet_gelu':
return 512
elif workload_name == 'imagenet_vit':
return 1024
elif workload_name == 'librispeech_conformer':
Expand Down
4 changes: 4 additions & 0 deletions reference_algorithms/paper_baselines/nadamw/jax/submission.py
Original file line number Diff line number Diff line change
Expand Up @@ -307,6 +307,10 @@ def get_batch_size(workload_name):
return 32
elif workload_name == 'imagenet_resnet':
return 1024
elif workload_name == 'imagenet_resnet_silu':
return 512
elif workload_name == 'imagenet_resnet_gelu':
return 512
elif workload_name == 'imagenet_vit':
return 1024
elif workload_name == 'librispeech_conformer':
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -309,6 +309,10 @@ def get_batch_size(workload_name):
return 32
elif workload_name == 'imagenet_resnet':
return 1024
elif workload_name == 'imagenet_resnet_silu':
return 512
elif workload_name == 'imagenet_resnet_gelu':
return 512
elif workload_name == 'imagenet_vit':
return 1024
elif workload_name == 'librispeech_conformer':
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -199,6 +199,10 @@ def get_batch_size(workload_name):
return 32
elif workload_name == 'imagenet_resnet':
return 1024
elif workload_name == 'imagenet_resnet_silu':
return 512
elif workload_name == 'imagenet_resnet_gelu':
return 512
elif workload_name == 'imagenet_vit':
return 1024
elif workload_name == 'librispeech_conformer':
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -152,6 +152,10 @@ def get_batch_size(workload_name):
return 32
elif workload_name == 'imagenet_resnet':
return 1024
elif workload_name == 'imagenet_resnet_silu':
return 512
elif workload_name == 'imagenet_resnet_gelu':
return 512
elif workload_name == 'imagenet_vit':
return 1024
elif workload_name == 'librispeech_conformer':
Expand Down
3 changes: 2 additions & 1 deletion submissions/template/submission.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,8 @@ def get_batch_size(workload_name):
Args:
workload_name (str): Valid workload_name values are: "wmt", "ogbg",
"criteo1tb", "fastmri", "imagenet_resnet", "imagenet_vit",
"librispeech_deepspeech", "librispeech_conformer".
"librispeech_deepspeech", "librispeech_conformer" or any of the
variants.
Returns:
int: batch_size
Raises:
Expand Down

0 comments on commit d329e42

Please sign in to comment.