diff --git a/baselines/nadamw/pytorch/submission.py b/baselines/nadamw/pytorch/submission.py index c3a1b7b56..ab0ce1318 100644 --- a/baselines/nadamw/pytorch/submission.py +++ b/baselines/nadamw/pytorch/submission.py @@ -48,8 +48,6 @@ def __init__(self, betas=(0.9, 0.999), eps=1e-8, weight_decay=1e-2, - fused=False, - foreach=False, ): if not 0.0 <= lr: raise ValueError(f'Invalid learning rate: {lr}') @@ -63,7 +61,6 @@ def __init__(self, raise ValueError(f'Invalid weight_decay value: {weight_decay}') defaults = { 'lr': lr, 'betas': betas, 'eps': eps, 'weight_decay': weight_decay, - 'fused': fused, 'foreach': foreach, } super().__init__(params, defaults) @@ -211,7 +208,6 @@ def init_optimizer_state(workload: spec.Workload, hyperparameters.beta2), eps=1e-8, weight_decay=hyperparameters.weight_decay, - fused=False, ) } diff --git a/reference_algorithms/target_setting_algorithms/pytorch_nadamw.py b/reference_algorithms/target_setting_algorithms/pytorch_nadamw.py index 58d39ae5b..b6e5ba61e 100644 --- a/reference_algorithms/target_setting_algorithms/pytorch_nadamw.py +++ b/reference_algorithms/target_setting_algorithms/pytorch_nadamw.py @@ -45,8 +45,6 @@ def __init__(self, betas=(0.9, 0.999), eps=1e-8, weight_decay=1e-2, - fused=False, - foreach=False, ): if not 0.0 <= lr: raise ValueError(f'Invalid learning rate: {lr}') @@ -60,7 +58,6 @@ def __init__(self, raise ValueError(f'Invalid weight_decay value: {weight_decay}') defaults = { 'lr': lr, 'betas': betas, 'eps': eps, 'weight_decay': weight_decay, - 'fused': fused, 'foreach': foreach, } super().__init__(params, defaults) @@ -208,7 +205,7 @@ def init_optimizer_state(workload: spec.Workload, betas=(hyperparameters.beta1, hyperparameters.beta2), eps=epsilon, weight_decay=hyperparameters.weight_decay, - fused=False), + ), } target_setting_step_hint = int(0.75 * workload.step_hint) diff --git a/submission_runner.py b/submission_runner.py index 2289d39d3..b7000d8ba 100644 --- a/submission_runner.py +++ b/submission_runner.py @@ -206,7 +206,7 @@ def train_once( model_params, model_state = workload.init_model_fn( model_init_rng, dropout_rate, aux_dropout_rate) if FLAGS.framework == 'pytorch' and FLAGS.torch_compile: - compile_error_workloads = ['ogbg', 'criteo1tb'] + compile_error_workloads = ['ogbg', 'criteo1tb', 'librispeech_conformer'] eager_backend_workloads = [ 'librispeech_conformer', 'librispeech_deepspeech' ]