Commit 7bb74a1: fix utils
ClashLuke committed Dec 7, 2024
1 parent 459d7d6 commit 7bb74a1
Showing 2 changed files with 9 additions and 9 deletions.
12 changes: 6 additions & 6 deletions benchmark/beale.py
@@ -1,11 +1,12 @@
 import itertools
 from typing import List
 
-import heavyball
 import torch
 import torch.backends.opt_einsum
 import torch.nn as nn
 import typer
 
+import heavyball
 from heavyball.utils import set_torch
 from utils import trial
+
@@ -25,9 +26,8 @@ def forward(self, inp):

 @app.command()
 def main(method: List[str] = typer.Option(['qr'], help='Eigenvector method to use (for SOAP)'),
-         dtype: List[str] = typer.Option(["float32"], help='Data type to use'), steps: int = 30_000,
-         weight_decay: float = 0,
-         opt: List[str] = typer.Option(['SFAdamW'], help='Optimizers to use')):
+         dtype: List[str] = typer.Option(["float32"], help='Data type to use'), steps: int = 10_000,
+         weight_decay: float = 0, opt: List[str] = typer.Option(heavyball.__all__, help='Optimizers to use')):
     dtype = [getattr(torch, d) for d in dtype]
     for args in itertools.product(method, dtype, opt, [weight_decay]):
         m, d, o, wd = args
@@ -41,8 +41,8 @@ def data():
     def win(_model, loss):
         return loss < 1e-5
 
-    trial(model, data, torch.nn.functional.mse_loss, win, steps, o, d, 1, 1, wd, m, 1, 1, group=5_000,
-          base_lr=1e-4, trials=30)
+    trial(model, data, torch.nn.functional.mse_loss, win, steps, o, d, 1, 1, wd, m, 1, 1, group=1_000, base_lr=1e-4,
+          trials=30)
 
 
 if __name__ == '__main__':
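
For context (not part of the diff): defaulting opt to heavyball.__all__ makes the benchmark sweep every optimizer the package exports instead of only SFAdamW, and itertools.product then crosses the optimizers with each method/dtype/weight-decay combination. A minimal sketch of that sweep, using placeholder option values rather than the real defaults:

import itertools

# Placeholder values standing in for the typer.Option defaults.
method, dtype, opt, weight_decay = ['qr'], ['float32'], ['SFAdamW', 'ForeachSOAP'], 0.0

for m, d, o, wd in itertools.product(method, dtype, opt, [weight_decay]):
    print(f'method={m} dtype={d} optimizer={o} weight_decay={wd}')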
6 changes: 3 additions & 3 deletions heavyball/utils.py
@@ -726,7 +726,7 @@ def _fused_compilable_adam_(y: List[Tensor], exp_avg: List[Tensor], exp_avg_sq:
 def fused_adam_(y: List[Tensor], exp_avg: List[Tensor], exp_avg_sq: List[Tensor], grad: List[Tensor], beta1: float,
                 beta2: float, step: int, lr: float, eps: float, decay: float, caution: bool):
     y, exp_avg, exp_avg_sq, grad = map(list_guard, (y, exp_avg, exp_avg_sq, grad))
-    beta1, beta2, step, lr, eps = map(scalar_guard, (beta1, beta2, step, lr, eps), y[0])
+    beta1, beta2, step, lr = [scalar_guard(x, y[0]) for x in (beta1, beta2, step, lr)]
     return _fused_compilable_adam_(y, exp_avg, exp_avg_sq, grad, beta1, beta2, step, decay, lr, eps, caution)
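
For context (not part of the diff): when map() is given a second iterable it zips its arguments, so the old map(scalar_guard, (beta1, beta2, step, lr, eps), y[0]) iterated over y[0]'s leading dimension and paired each scalar with a single slice of the first parameter tensor instead of passing y[0] itself. The list comprehension calls scalar_guard(x, y[0]) with the full reference tensor every time, and eps is now forwarded as a plain Python float; the same change recurs in fused_laprop_ and fused_adopt_ below. A minimal sketch of the difference, assuming scalar_guard wraps a Python scalar as a tensor matching the reference's device and dtype:

import torch

def scalar_guard(x, ref):  # assumed behavior: tensor-ify a Python scalar to match `ref`
    return torch.tensor(x, device=ref.device, dtype=ref.dtype)

ref = torch.zeros(4, 3)
scalars = (0.9, 0.999, 10, 1e-3)

# Old pattern: map() zips `scalars` with the rows of `ref`, so each
# scalar is guarded against one row rather than the whole tensor.
old = list(map(scalar_guard, scalars, ref))

# New pattern from this commit: every scalar sees the full reference tensor.
new = [scalar_guard(x, ref) for x in scalars]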


@@ -773,7 +773,7 @@ def _fused_compilable_laprop_(y: List[Tensor], exp_avg: List[Tensor], exp_avg_sq
 def fused_laprop_(y: List[Tensor], exp_avg: List[Tensor], exp_avg_sq: List[Tensor], grad_projected: List[Tensor],
                   beta1: float, beta2: float, step: int, lr: float, decay: float, caution: bool):
     y, exp_avg, exp_avg_sq, grad_projected = map(list_guard, (y, exp_avg, exp_avg_sq, grad_projected))
-    beta1, beta2, step, lr = map(scalar_guard, (beta1, beta2, step, lr), y[0])
+    beta1, beta2, step, lr = [scalar_guard(x, y[0]) for x in (beta1, beta2, step, lr)]
     _fused_compilable_laprop_(y, exp_avg, exp_avg_sq, grad_projected, beta1, beta2, step, lr, decay, caution)


@@ -798,7 +798,7 @@ def _fused_compilable_adopt_(y, grad, exp_avg_sq, exp_avg, beta1, beta2, step, l

 def fused_adopt_(y, grad, exp_avg_sq, exp_avg, beta1, beta2, step, lr, eps, decay, caution):
     y, grad, exp_avg_sq, exp_avg = list_guard(y), list_guard(grad), list_guard(exp_avg_sq), list_guard(exp_avg)
-    beta1, beta2, step, lr, eps = map(scalar_guard, (beta1, beta2, step, lr, eps), y[0])
+    beta1, beta2, step, lr = [scalar_guard(x, y[0]) for x in (beta1, beta2, step, lr)]
     _fused_compilable_adopt_(y, grad, exp_avg_sq, exp_avg, beta1, beta2, step, lr, eps, decay, caution)


