Skip to content

Commit

Permalink
turn off torch compile
Browse files Browse the repository at this point in the history
  • Loading branch information
priyakasimbeg committed Sep 22, 2023
1 parent 77c1d60 commit cd7ee84
Show file tree
Hide file tree
Showing 3 changed files with 2 additions and 9 deletions.
4 changes: 0 additions & 4 deletions baselines/nadamw/pytorch/submission.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,8 +48,6 @@ def __init__(self,
betas=(0.9, 0.999),
eps=1e-8,
weight_decay=1e-2,
fused=False,
foreach=False,
):
if not 0.0 <= lr:
raise ValueError(f'Invalid learning rate: {lr}')
Expand All @@ -63,7 +61,6 @@ def __init__(self,
raise ValueError(f'Invalid weight_decay value: {weight_decay}')
defaults = {
'lr': lr, 'betas': betas, 'eps': eps, 'weight_decay': weight_decay,
'fused': fused, 'foreach': foreach,
}
super().__init__(params, defaults)

Expand Down Expand Up @@ -211,7 +208,6 @@ def init_optimizer_state(workload: spec.Workload,
hyperparameters.beta2),
eps=1e-8,
weight_decay=hyperparameters.weight_decay,
fused=False,
)
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,8 +45,6 @@ def __init__(self,
betas=(0.9, 0.999),
eps=1e-8,
weight_decay=1e-2,
fused=False,
foreach=False,
):
if not 0.0 <= lr:
raise ValueError(f'Invalid learning rate: {lr}')
Expand All @@ -60,7 +58,6 @@ def __init__(self,
raise ValueError(f'Invalid weight_decay value: {weight_decay}')
defaults = {
'lr': lr, 'betas': betas, 'eps': eps, 'weight_decay': weight_decay,
'fused': fused, 'foreach': foreach,
}
super().__init__(params, defaults)

Expand Down Expand Up @@ -208,7 +205,7 @@ def init_optimizer_state(workload: spec.Workload,
betas=(hyperparameters.beta1, hyperparameters.beta2),
eps=epsilon,
weight_decay=hyperparameters.weight_decay,
fused=False),
),
}

target_setting_step_hint = int(0.75 * workload.step_hint)
Expand Down
2 changes: 1 addition & 1 deletion submission_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -206,7 +206,7 @@ def train_once(
model_params, model_state = workload.init_model_fn(
model_init_rng, dropout_rate, aux_dropout_rate)
if FLAGS.framework == 'pytorch' and FLAGS.torch_compile:
compile_error_workloads = ['ogbg', 'criteo1tb']
compile_error_workloads = ['ogbg', 'criteo1tb', 'librispeech_conformer']
eager_backend_workloads = [
'librispeech_conformer', 'librispeech_deepspeech'
]
Expand Down

0 comments on commit cd7ee84

Please sign in to comment.