diff --git a/baselines/nadamw/pytorch/submission.py b/baselines/nadamw/pytorch/submission.py index d5180af3c..354d89f1e 100644 --- a/baselines/nadamw/pytorch/submission.py +++ b/baselines/nadamw/pytorch/submission.py @@ -47,7 +47,8 @@ def __init__(self, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, - weight_decay=1e-2): + weight_decay=1e-2, + fused=False,): if not 0.0 <= lr: raise ValueError(f'Invalid learning rate: {lr}') if not 0.0 <= eps: @@ -59,7 +60,8 @@ def __init__(self, if not 0.0 <= weight_decay: raise ValueError(f'Invalid weight_decay value: {weight_decay}') defaults = { - 'lr': lr, 'betas': betas, 'eps': eps, 'weight_decay': weight_decay + 'lr': lr, 'betas': betas, 'eps': eps, 'weight_decay': weight_decay, + 'fused': fused, } super().__init__(params, defaults) diff --git a/reference_algorithms/target_setting_algorithms/pytorch_nadamw.py b/reference_algorithms/target_setting_algorithms/pytorch_nadamw.py index f45454a27..d01d52b7a 100644 --- a/reference_algorithms/target_setting_algorithms/pytorch_nadamw.py +++ b/reference_algorithms/target_setting_algorithms/pytorch_nadamw.py @@ -44,7 +44,8 @@ def __init__(self, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, - weight_decay=1e-2): + weight_decay=1e-2, + fused=False): if not 0.0 <= lr: raise ValueError(f'Invalid learning rate: {lr}') if not 0.0 <= eps: @@ -56,7 +57,8 @@ def __init__(self, if not 0.0 <= weight_decay: raise ValueError(f'Invalid weight_decay value: {weight_decay}') defaults = { - 'lr': lr, 'betas': betas, 'eps': eps, 'weight_decay': weight_decay + 'lr': lr, 'betas': betas, 'eps': eps, 'weight_decay': weight_decay, + 'fused': fused, } super().__init__(params, defaults)