saving work refactoring target_setting_runs/
znado committed Oct 7, 2022
1 parent 517a442 commit fa7359b
Showing 34 changed files with 325 additions and 834 deletions.
138 changes: 138 additions & 0 deletions target_setting_runs/README.md
@@ -0,0 +1,138 @@
# Target-Setting Run Replications
The original target-setting runs were performed on Google TPU v2-8 machines.

## Criteo
The target was set using AdamW with a linear warmup and cosine decay LR schedule.
```bash
python3 submission_runner.py \
--framework=jax \
--workload=criteo1tb \
--submission_path=target_setting_runs/jax_adamw.py \
--tuning_search_space=target_setting_runs/criteo1tb/tuning_search_space.json
```
```bash
python3 submission_runner.py \
--framework=pytorch \
--workload=criteo1tb \
--submission_path=target_setting_runs/pytorch_adamw.py \
--tuning_search_space=target_setting_runs/criteo1tb/tuning_search_space.json
```

## FastMRI
The target was set using NAdamW with a linear warmup and cosine decay LR schedule.
```bash
python3 submission_runner.py \
--framework=jax \
--workload=fastmri \
--submission_path=target_setting_runs/jax_nadamw.py \
--tuning_search_space=target_setting_runs/fastmri/tuning_search_space.json
```
```bash
python3 submission_runner.py \
--framework=pytorch \
--workload=fastmri \
--submission_path=target_setting_runs/pytorch_nadamw.py \
--tuning_search_space=target_setting_runs/fastmri/tuning_search_space.json
```

## ImageNet-ResNet
The target was set using Nesterov momentum with a linear warmup and linear decay LR schedule.
```bash
python3 submission_runner.py \
--framework=jax \
--workload=imagenet_resnet \
--submission_path=target_setting_runs/jax_nesterov.py \
--tuning_search_space=target_setting_runs/imagenet_resnet/tuning_search_space.json
```
```bash
python3 submission_runner.py \
--framework=pytorch \
--workload=imagenet_resnet \
--submission_path=target_setting_runs/pytorch_nesterov.py \
--tuning_search_space=target_setting_runs/imagenet_resnet/tuning_search_space.json
```

## ImageNet-ViT
The target was set using NAdamW with a linear warmup and cosine decay LR schedule.
```bash
python3 submission_runner.py \
--framework=jax \
--workload=imagenet_vit \
--submission_path=target_setting_runs/jax_nadamw.py \
--tuning_search_space=target_setting_runs/imagenet_vit/tuning_search_space.json
```
```bash
python3 submission_runner.py \
--framework=pytorch \
--workload=imagenet_vit \
--submission_path=target_setting_runs/pytorch_nadamw.py \
--tuning_search_space=target_setting_runs/imagenet_vit/tuning_search_space.json
```

## LibriSpeech-Conformer
The target was set using AdamW with a linear warmup and cosine decay LR schedule.
```bash
python3 submission_runner.py \
--framework=jax \
--workload=librispeech_conformer \
--submission_path=target_setting_runs/jax_adamw.py \
--tuning_search_space=target_setting_runs/librispeech_conformer/tuning_search_space.json
```
```bash
python3 submission_runner.py \
--framework=pytorch \
--workload=librispeech_conformer \
--submission_path=target_setting_runs/pytorch_adamw.py \
--tuning_search_space=target_setting_runs/librispeech_conformer/tuning_search_space.json
```

## LibriSpeech-DeepSpeech
The target was set using NAdamW with a linear warmup and cosine decay LR schedule.
```bash
python3 submission_runner.py \
--framework=jax \
--workload=librispeech_deepspeech \
--submission_path=target_setting_runs/jax_nadamw.py \
--tuning_search_space=target_setting_runs/librispeech_deepspeech/tuning_search_space.json
```
```bash
python3 submission_runner.py \
--framework=pytorch \
--workload=librispeech_deepspeech \
--submission_path=target_setting_runs/pytorch_nadamw.py \
--tuning_search_space=target_setting_runs/librispeech_deepspeech/tuning_search_space.json
```

## OGBG
The target was set using Nesterov momentum with a linear warmup and linear decay LR schedule.
```bash
python3 submission_runner.py \
--framework=jax \
--workload=ogbg \
--submission_path=target_setting_runs/jax_nesterov.py \
--tuning_search_space=target_setting_runs/ogbg/tuning_search_space.json
```
```bash
python3 submission_runner.py \
--framework=pytorch \
--workload=ogbg \
--submission_path=target_setting_runs/pytorch_nesterov.py \
--tuning_search_space=target_setting_runs/ogbg/tuning_search_space.json
```

## WMT
The target was set using AdamW with a linear warmup and cosine decay LR schedule.
```bash
python3 submission_runner.py \
--framework=jax \
--workload=wmt \
--submission_path=target_setting_runs/jax_adamw.py \
--tuning_search_space=target_setting_runs/wmt/tuning_search_space.json
```
```bash
python3 submission_runner.py \
--framework=pytorch \
--workload=wmt \
--submission_path=target_setting_runs/pytorch_adamw.py \
--tuning_search_space=target_setting_runs/wmt/tuning_search_space.json
```
37 changes: 37 additions & 0 deletions target_setting_runs/cosine_warmup.py
@@ -0,0 +1,37 @@
"""Implementions of a linear warmup then cosine decay LR schedule."""

import optax
from torch.optim.lr_scheduler import CosineAnnealingLR
from torch.optim.lr_scheduler import LinearLR
from torch.optim.lr_scheduler import SequentialLR


def jax_cosine_warmup(hyperparameters):
# Create learning rate schedule.
warmup_fn = optax.linear_schedule(
init_value=0.,
end_value=hyperparameters.learning_rate,
transition_steps=hyperparameters.warmup_steps)
cosine_steps = max(hyperparameters.num_steps - hyperparameters.warmup_steps,
1)
cosine_fn = optax.cosine_decay_schedule(
init_value=hyperparameters.learning_rate, decay_steps=cosine_steps)
schedule_fn = optax.join_schedules(
schedules=[warmup_fn, cosine_fn],
boundaries=[hyperparameters.warmup_steps])
return schedule_fn


def pytorch_cosine_warmup(hyperparameters, optimizer):
warmup = LinearLR(
optimizer,
start_factor=1e-10,
end_factor=1.,
total_iters=hyperparameters.warmup_steps)
cosine_steps = max(hyperparameters.num_steps - hyperparameters.warmup_steps,
1)
cosine_decay = CosineAnnealingLR(optimizer, T_max=cosine_steps)
return SequentialLR(
optimizer,
schedulers=[warmup, cosine_decay],
milestones=[hyperparameters.warmup_steps])
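
For context, here is a minimal usage sketch of the two constructors above. The `Hyperparameters` tuple is a hypothetical stand-in for the object that `submission_runner.py` passes in; only the `learning_rate`, `warmup_steps`, and `num_steps` fields used above are assumed.

```python
import collections

import optax
import torch

# Hypothetical stand-in for the hyperparameters object used above.
Hyperparameters = collections.namedtuple(
    'Hyperparameters', ['learning_rate', 'warmup_steps', 'num_steps'])
hps = Hyperparameters(learning_rate=1e-3, warmup_steps=100, num_steps=1000)

# JAX: the returned schedule function plugs directly into an optax optimizer.
schedule_fn = jax_cosine_warmup(hps)
jax_optimizer = optax.adamw(learning_rate=schedule_fn)

# PyTorch: wrap an existing optimizer, then call .step() once per batch.
params = [torch.nn.Parameter(torch.zeros(1))]
torch_optimizer = torch.optim.AdamW(params, lr=hps.learning_rate)
scheduler = pytorch_cosine_warmup(hps, torch_optimizer)
for _ in range(hps.num_steps):
    torch_optimizer.step()  # ...after computing gradients for a batch.
    scheduler.step()
```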
Empty file.
Empty file.
2 changes: 1 addition & 1 deletion target_setting_runs/criteo1tb/tuning_search_space.json
@@ -16,7 +16,7 @@
   },
   "warmup_steps": {
     "feasible_points": [
-      1600
+      200
     ]
   },
   "num_steps": {
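
For reference, a sketch of how one of these search-space files can be inspected; the structure is inferred from the fragment above, where each hyperparameter maps to a list of `feasible_points`:

```python
import json

# Path taken from the --tuning_search_space flags in the README above.
with open('target_setting_runs/criteo1tb/tuning_search_space.json') as f:
    search_space = json.load(f)

# Per the diff above, warmup_steps now has a single feasible point, 200.
print(search_space['warmup_steps']['feasible_points'])  # Expected: [200]
```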
Empty file.
Empty file.
23 changes: 23 additions & 0 deletions target_setting_runs/get_batch_size.py
@@ -0,0 +1,23 @@
"""Batch size selection submission function."""


def get_batch_size(workload_name):
# Return the global batch size.
if workload_name == 'criteo1tb_dlrm':
return 524288
elif workload_name == 'fastmri':
return 32
elif workload_name == 'imagenet_resnet':
return 1024
elif workload_name == 'imagenet_vit':
return 1024
elif workload_name == 'librispeech_conformer':
return 256
elif workload_name == 'librispeech_deepspeech':
return 256
elif workload_name == 'ogbg':
return 512
elif workload_name == 'wmt':
return 128
else:
raise ValueError(f'Unsupported workload name: {workload_name}.')
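
For a quick sanity check, the mapping can be exercised directly; the values here are taken verbatim from the branches above:

```python
# Workload names and global batch sizes as defined in get_batch_size().
assert get_batch_size('wmt') == 128
assert get_batch_size('criteo1tb_dlrm') == 524288
```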
Empty file.
94 changes: 0 additions & 94 deletions target_setting_runs/imagenet_resnet/jax_submission.py

This file was deleted.

63 changes: 0 additions & 63 deletions target_setting_runs/imagenet_resnet/pytorch_submission.py

This file was deleted.

Empty file.