From d63214bde21247d4aaabc5b6f7ad8868fe74efa0 Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Thu, 21 Sep 2023 23:23:16 +0000 Subject: [PATCH] set fused to false for pytorch nadamw --- baselines/nadamw/pytorch/submission.py | 4 +++- .../target_setting_algorithms/pytorch_nadamw.py | 3 ++- 2 files changed, 5 insertions(+), 2 deletions(-) diff --git a/baselines/nadamw/pytorch/submission.py b/baselines/nadamw/pytorch/submission.py index 01cffc52e..d5180af3c 100644 --- a/baselines/nadamw/pytorch/submission.py +++ b/baselines/nadamw/pytorch/submission.py @@ -206,7 +206,9 @@ def init_optimizer_state(workload: spec.Workload, betas=(1.0 - hyperparameters.one_minus_beta1, hyperparameters.beta2), eps=1e-8, - weight_decay=hyperparameters.weight_decay), + weight_decay=hyperparameters.weight_decay, + fused=False, + ) } def pytorch_cosine_warmup(step_hint: int, hyperparameters, optimizer): diff --git a/reference_algorithms/target_setting_algorithms/pytorch_nadamw.py b/reference_algorithms/target_setting_algorithms/pytorch_nadamw.py index 71b819e66..f45454a27 100644 --- a/reference_algorithms/target_setting_algorithms/pytorch_nadamw.py +++ b/reference_algorithms/target_setting_algorithms/pytorch_nadamw.py @@ -203,7 +203,8 @@ def init_optimizer_state(workload: spec.Workload, lr=hyperparameters.learning_rate, betas=(hyperparameters.beta1, hyperparameters.beta2), eps=epsilon, - weight_decay=hyperparameters.weight_decay), + weight_decay=hyperparameters.weight_decay, + fused=False), } target_setting_step_hint = int(0.75 * workload.step_hint)