From 77c1d60b18ec95c9de8e896295c21ef5e3165080 Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Fri, 22 Sep 2023 00:34:21 +0000 Subject: [PATCH] set foreach false in nadamw --- baselines/nadamw/pytorch/submission.py | 6 ++++-- .../target_setting_algorithms/pytorch_nadamw.py | 6 ++++-- 2 files changed, 8 insertions(+), 4 deletions(-) diff --git a/baselines/nadamw/pytorch/submission.py b/baselines/nadamw/pytorch/submission.py index 354d89f1e..c3a1b7b56 100644 --- a/baselines/nadamw/pytorch/submission.py +++ b/baselines/nadamw/pytorch/submission.py @@ -48,7 +48,9 @@ def __init__(self, betas=(0.9, 0.999), eps=1e-8, weight_decay=1e-2, - fused=False,): + fused=False, + foreach=False, + ): if not 0.0 <= lr: raise ValueError(f'Invalid learning rate: {lr}') if not 0.0 <= eps: @@ -61,7 +63,7 @@ def __init__(self, raise ValueError(f'Invalid weight_decay value: {weight_decay}') defaults = { 'lr': lr, 'betas': betas, 'eps': eps, 'weight_decay': weight_decay, - 'fused': fused, + 'fused': fused, 'foreach': foreach, } super().__init__(params, defaults) diff --git a/reference_algorithms/target_setting_algorithms/pytorch_nadamw.py b/reference_algorithms/target_setting_algorithms/pytorch_nadamw.py index d01d52b7a..58d39ae5b 100644 --- a/reference_algorithms/target_setting_algorithms/pytorch_nadamw.py +++ b/reference_algorithms/target_setting_algorithms/pytorch_nadamw.py @@ -45,7 +45,9 @@ def __init__(self, betas=(0.9, 0.999), eps=1e-8, weight_decay=1e-2, - fused=False): + fused=False, + foreach=False, + ): if not 0.0 <= lr: raise ValueError(f'Invalid learning rate: {lr}') if not 0.0 <= eps: @@ -58,7 +60,7 @@ def __init__(self, raise ValueError(f'Invalid weight_decay value: {weight_decay}') defaults = { 'lr': lr, 'betas': betas, 'eps': eps, 'weight_decay': weight_decay, - 'fused': fused, + 'fused': fused, 'foreach': foreach, } super().__init__(params, defaults)