Commit: saving work refactoring target_setting_runs/
Showing 34 changed files with 325 additions and 834 deletions.
# Target Setting Run Replications
The original target-setting runs were performed on Google TPU v2-8 machines.

## Criteo
The target was set using AdamW with a linear warmup and cosine decay LR schedule.
```bash
python3 submission_runner.py \
    --framework=jax \
    --workload=criteo1tb \
    --submission_path=target_setting_runs/jax_adamw.py \
    --tuning_search_space=target_setting_runs/criteo1tb/tuning_search_space.json
```
```bash
python3 submission_runner.py \
    --framework=pytorch \
    --workload=criteo1tb \
    --submission_path=target_setting_runs/pytorch_adamw.py \
    --tuning_search_space=target_setting_runs/criteo1tb/tuning_search_space.json
```

## FastMRI
The target was set using NAdamW with a linear warmup and cosine decay LR schedule.
```bash
python3 submission_runner.py \
    --framework=jax \
    --workload=fastmri \
    --submission_path=target_setting_runs/jax_nadamw.py \
    --tuning_search_space=target_setting_runs/fastmri/tuning_search_space.json
```
```bash
python3 submission_runner.py \
    --framework=pytorch \
    --workload=fastmri \
    --submission_path=target_setting_runs/pytorch_nadamw.py \
    --tuning_search_space=target_setting_runs/fastmri/tuning_search_space.json
```

## ImageNet-ResNet
The target was set using Nesterov momentum with a linear warmup and linear decay LR schedule.
```bash
python3 submission_runner.py \
    --framework=jax \
    --workload=imagenet_resnet \
    --submission_path=target_setting_runs/jax_nesterov.py \
    --tuning_search_space=target_setting_runs/imagenet_resnet/tuning_search_space.json
```
```bash
python3 submission_runner.py \
    --framework=pytorch \
    --workload=imagenet_resnet \
    --submission_path=target_setting_runs/pytorch_nesterov.py \
    --tuning_search_space=target_setting_runs/imagenet_resnet/tuning_search_space.json
```

## ImageNet-ViT
The target was set using NAdamW with a linear warmup and cosine decay LR schedule.
```bash
python3 submission_runner.py \
    --framework=jax \
    --workload=imagenet_vit \
    --submission_path=target_setting_runs/jax_nadamw.py \
    --tuning_search_space=target_setting_runs/imagenet_vit/tuning_search_space.json
```
```bash
python3 submission_runner.py \
    --framework=pytorch \
    --workload=imagenet_vit \
    --submission_path=target_setting_runs/pytorch_nadamw.py \
    --tuning_search_space=target_setting_runs/imagenet_vit/tuning_search_space.json
```

## LibriSpeech-Conformer
The target was set using AdamW with a linear warmup and cosine decay LR schedule.
```bash
python3 submission_runner.py \
    --framework=jax \
    --workload=librispeech_conformer \
    --submission_path=target_setting_runs/jax_adamw.py \
    --tuning_search_space=target_setting_runs/librispeech_conformer/tuning_search_space.json
```
```bash
python3 submission_runner.py \
    --framework=pytorch \
    --workload=librispeech_conformer \
    --submission_path=target_setting_runs/pytorch_adamw.py \
    --tuning_search_space=target_setting_runs/librispeech_conformer/tuning_search_space.json
```

## LibriSpeech-DeepSpeech
The target was set using NAdamW with a linear warmup and cosine decay LR schedule.
```bash
python3 submission_runner.py \
    --framework=jax \
    --workload=librispeech_deepspeech \
    --submission_path=target_setting_runs/jax_nadamw.py \
    --tuning_search_space=target_setting_runs/librispeech_deepspeech/tuning_search_space.json
```
```bash
python3 submission_runner.py \
    --framework=pytorch \
    --workload=librispeech_deepspeech \
    --submission_path=target_setting_runs/pytorch_nadamw.py \
    --tuning_search_space=target_setting_runs/librispeech_deepspeech/tuning_search_space.json
```

## OGBG
The target was set using Nesterov momentum with a linear warmup and linear decay LR schedule.
```bash
python3 submission_runner.py \
    --framework=jax \
    --workload=ogbg \
    --submission_path=target_setting_runs/jax_nesterov.py \
    --tuning_search_space=target_setting_runs/ogbg/tuning_search_space.json
```
```bash
python3 submission_runner.py \
    --framework=pytorch \
    --workload=ogbg \
    --submission_path=target_setting_runs/pytorch_nesterov.py \
    --tuning_search_space=target_setting_runs/ogbg/tuning_search_space.json
```

## WMT
The target was set using AdamW with a linear warmup and cosine decay LR schedule.
```bash
python3 submission_runner.py \
    --framework=jax \
    --workload=wmt \
    --submission_path=target_setting_runs/jax_adamw.py \
    --tuning_search_space=target_setting_runs/wmt/tuning_search_space.json
```
```bash
python3 submission_runner.py \
    --framework=pytorch \
    --workload=wmt \
    --submission_path=target_setting_runs/pytorch_adamw.py \
    --tuning_search_space=target_setting_runs/wmt/tuning_search_space.json
```
"""Implementions of a linear warmup then cosine decay LR schedule.""" | ||
|
||
import optax | ||
from torch.optim.lr_scheduler import CosineAnnealingLR | ||
from torch.optim.lr_scheduler import LinearLR | ||
from torch.optim.lr_scheduler import SequentialLR | ||
|
||
|
||
def jax_cosine_warmup(hyperparameters): | ||
# Create learning rate schedule. | ||
warmup_fn = optax.linear_schedule( | ||
init_value=0., | ||
end_value=hyperparameters.learning_rate, | ||
transition_steps=hyperparameters.warmup_steps) | ||
cosine_steps = max(hyperparameters.num_steps - hyperparameters.warmup_steps, | ||
1) | ||
cosine_fn = optax.cosine_decay_schedule( | ||
init_value=hyperparameters.learning_rate, decay_steps=cosine_steps) | ||
schedule_fn = optax.join_schedules( | ||
schedules=[warmup_fn, cosine_fn], | ||
boundaries=[hyperparameters.warmup_steps]) | ||
return schedule_fn | ||
|
||
|
||
def pytorch_cosine_warmup(hyperparameters, optimizer): | ||
warmup = LinearLR( | ||
optimizer, | ||
start_factor=1e-10, | ||
end_factor=1., | ||
total_iters=hyperparameters.warmup_steps) | ||
cosine_steps = max(hyperparameters.num_steps - hyperparameters.warmup_steps, | ||
1) | ||
cosine_decay = CosineAnnealingLR(optimizer, T_max=cosine_steps) | ||
return SequentialLR( | ||
optimizer, | ||
schedulers=[warmup, cosine_decay], | ||
milestones=[hyperparameters.warmup_steps]) |
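For orientation, here is a minimal usage sketch. The `Hyperparameters` namedtuple, the toy `torch.nn.Linear` model, and the bare training loop are illustrative assumptions and not part of this commit; only `jax_cosine_warmup` and `pytorch_cosine_warmup` come from the file above and are assumed to be in scope.

```python
# Usage sketch (assumed names; not part of this commit).
from collections import namedtuple

import optax
import torch

Hyperparameters = namedtuple(
    'Hyperparameters', ['learning_rate', 'warmup_steps', 'num_steps'])
hparams = Hyperparameters(learning_rate=1e-3, warmup_steps=200, num_steps=1000)

# JAX: optax optimizers accept a schedule function directly as learning_rate.
jax_optimizer = optax.adamw(learning_rate=jax_cosine_warmup(hparams))

# PyTorch: wrap an optimizer and step the scheduler once per training step.
model = torch.nn.Linear(4, 2)
optimizer = torch.optim.AdamW(model.parameters(), lr=hparams.learning_rate)
scheduler = pytorch_cosine_warmup(hparams, optimizer)
for _ in range(hparams.num_steps):
  optimizer.step()   # a real loop would compute a loss and call backward() first
  scheduler.step()   # advances the warmup, then the cosine decay
```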
One of the tuning_search_space.json files reduces warmup_steps from 1600 to 200:

@@ -16,7 +16,7 @@
   },
   "warmup_steps": {
     "feasible_points": [
-      1600
+      200
     ]
   },
   "num_steps": {
"""Batch size selection submission function.""" | ||
|
||
|
||
def get_batch_size(workload_name): | ||
# Return the global batch size. | ||
if workload_name == 'criteo1tb_dlrm': | ||
return 524288 | ||
elif workload_name == 'fastmri': | ||
return 32 | ||
elif workload_name == 'imagenet_resnet': | ||
return 1024 | ||
elif workload_name == 'imagenet_vit': | ||
return 1024 | ||
elif workload_name == 'librispeech_conformer': | ||
return 256 | ||
elif workload_name == 'librispeech_deepspeech': | ||
return 256 | ||
elif workload_name == 'ogbg': | ||
return 512 | ||
elif workload_name == 'wmt': | ||
return 128 | ||
else: | ||
raise ValueError(f'Unsupported workload name: {workload_name}.') |
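A quick, hypothetical call site follows; the runner-side lookup shown here is an assumption for illustration and is not part of this commit.

```python
# Hypothetical call site: the runner queries the submission for its global batch size.
assert get_batch_size('wmt') == 128
assert get_batch_size('criteo1tb_dlrm') == 524288
```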