From 7bb74a126b16beded2b846c2fd1e46fc922cc5c5 Mon Sep 17 00:00:00 2001
From: ClashLuke <39779310+ClashLuke@users.noreply.github.com>
Date: Sat, 7 Dec 2024 20:32:37 +0100
Subject: [PATCH] fix utils

---
 benchmark/beale.py | 12 ++++++------
 heavyball/utils.py |  6 +++---
 2 files changed, 9 insertions(+), 9 deletions(-)

diff --git a/benchmark/beale.py b/benchmark/beale.py
index e03eb20..16b2320 100644
--- a/benchmark/beale.py
+++ b/benchmark/beale.py
@@ -1,11 +1,12 @@
 import itertools
 from typing import List
 
-import heavyball
 import torch
 import torch.backends.opt_einsum
 import torch.nn as nn
 import typer
+
+import heavyball
 from heavyball.utils import set_torch
 from utils import trial
 
@@ -25,9 +26,8 @@ def forward(self, inp):
 
 @app.command()
 def main(method: List[str] = typer.Option(['qr'], help='Eigenvector method to use (for SOAP)'),
-         dtype: List[str] = typer.Option(["float32"], help='Data type to use'), steps: int = 30_000,
-         weight_decay: float = 0,
-         opt: List[str] = typer.Option(['SFAdamW'], help='Optimizers to use')):
+         dtype: List[str] = typer.Option(["float32"], help='Data type to use'), steps: int = 10_000,
+         weight_decay: float = 0, opt: List[str] = typer.Option(heavyball.__all__, help='Optimizers to use')):
     dtype = [getattr(torch, d) for d in dtype]
     for args in itertools.product(method, dtype, opt, [weight_decay]):
         m, d, o, wd = args
@@ -41,8 +41,8 @@ def data():
         def win(_model, loss):
             return loss < 1e-5
 
-        trial(model, data, torch.nn.functional.mse_loss, win, steps, o, d, 1, 1, wd, m, 1, 1, group=5_000,
-              base_lr=1e-4, trials=30)
+        trial(model, data, torch.nn.functional.mse_loss, win, steps, o, d, 1, 1, wd, m, 1, 1, group=1_000, base_lr=1e-4,
+              trials=30)
 
 
 if __name__ == '__main__':
diff --git a/heavyball/utils.py b/heavyball/utils.py
index b6b2379..92d184b 100644
--- a/heavyball/utils.py
+++ b/heavyball/utils.py
@@ -726,7 +726,7 @@ def _fused_compilable_adam_(y: List[Tensor], exp_avg: List[Tensor], exp_avg_sq:
 def fused_adam_(y: List[Tensor], exp_avg: List[Tensor], exp_avg_sq: List[Tensor], grad: List[Tensor], beta1: float,
                 beta2: float, step: int, lr: float, eps: float, decay: float, caution: bool):
     y, exp_avg, exp_avg_sq, grad = map(list_guard, (y, exp_avg, exp_avg_sq, grad))
-    beta1, beta2, step, lr, eps = map(scalar_guard, (beta1, beta2, step, lr, eps), y[0])
+    beta1, beta2, step, lr = [scalar_guard(x, y[0]) for x in (beta1, beta2, step, lr)]
     return _fused_compilable_adam_(y, exp_avg, exp_avg_sq, grad, beta1, beta2, step, decay, lr, eps, caution)
 
 
@@ -773,7 +773,7 @@ def _fused_compilable_laprop_(y: List[Tensor], exp_avg: List[Tensor], exp_avg_sq
 def fused_laprop_(y: List[Tensor], exp_avg: List[Tensor], exp_avg_sq: List[Tensor], grad_projected: List[Tensor],
                   beta1: float, beta2: float, step: int, lr: float, decay: float, caution: bool):
     y, exp_avg, exp_avg_sq, grad_projected = map(list_guard, (y, exp_avg, exp_avg_sq, grad_projected))
-    beta1, beta2, step, lr = map(scalar_guard, (beta1, beta2, step, lr), y[0])
+    beta1, beta2, step, lr = [scalar_guard(x, y[0]) for x in (beta1, beta2, step, lr)]
     _fused_compilable_laprop_(y, exp_avg, exp_avg_sq, grad_projected, beta1, beta2, step, lr, decay, caution)
 
 
@@ -798,7 +798,7 @@ def _fused_compilable_adopt_(y, grad, exp_avg_sq, exp_avg, beta1, beta2, step, l
 
 def fused_adopt_(y, grad, exp_avg_sq, exp_avg, beta1, beta2, step, lr, eps, decay, caution):
     y, grad, exp_avg_sq, exp_avg = list_guard(y), list_guard(grad), list_guard(exp_avg_sq), list_guard(exp_avg)
-    beta1, beta2, step, lr, eps = map(scalar_guard, (beta1, beta2, step, lr, eps), y[0])
+    beta1, beta2, step, lr = [scalar_guard(x, y[0]) for x in (beta1, beta2, step, lr)]
     _fused_compilable_adopt_(y, grad, exp_avg_sq, exp_avg, beta1, beta2, step, lr, eps, decay, caution)
 
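
Note on the heavyball/utils.py change: `map(scalar_guard, (beta1, beta2, step, lr, eps), y[0])` passes two iterables to `map`, so the scalar tuple is zipped against the *elements* of `y[0]` rather than the whole tensor being handed to every call, and the result is silently truncated to `len(y[0])` entries. The list comprehension applies `scalar_guard(x, y[0])` once per scalar. The sketch below illustrates the difference; `_scalar_guard` is a hypothetical stand-in that assumes `scalar_guard` wraps a Python scalar as a tensor matching a reference tensor's dtype and device.

```python
import torch


def _scalar_guard(x, ref):
    # Hypothetical stand-in for heavyball's scalar_guard: wrap a Python
    # scalar as a tensor on the reference tensor's device and dtype.
    return torch.as_tensor(x, dtype=ref.dtype, device=ref.device)


y0 = torch.zeros(3)                  # stand-in for the first parameter tensor y[0]
scalars = (0.9, 0.95, 10, 1e-3)      # beta1, beta2, step, lr

# Old pattern: map() with two iterables zips them element-wise, so each call
# receives one *element* of y0 as its second argument and the output is
# silently truncated to len(y0) values.
old = list(map(_scalar_guard, scalars, y0))
print(len(old))                      # 3 -- the last scalar was dropped

# Patched pattern: y0 is passed whole to every call, one result per scalar.
new = [_scalar_guard(x, y0) for x in scalars]
print(len(new))                      # 4
```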