diff --git a/examples/simple_trainer.py b/examples/simple_trainer.py
index f3c748236..93e70002f 100644
--- a/examples/simple_trainer.py
+++ b/examples/simple_trainer.py
@@ -764,7 +764,7 @@ def train(self):
             # optimize
             for optimizer in self.optimizers.values():
                 if cfg.visible_adam:
-                    optimizer.step(visibility_mask, gaussian_cnt)
+                    optimizer.step(visibility_mask)
                 else:
                     optimizer.step()
                 optimizer.zero_grad(set_to_none=True)
diff --git a/gsplat/optimizers/selective_adam.py b/gsplat/optimizers/selective_adam.py
index 9bf37e1fe..e7decf7a9 100644
--- a/gsplat/optimizers/selective_adam.py
+++ b/gsplat/optimizers/selective_adam.py
@@ -37,7 +37,7 @@ class SelectiveAdam(torch.optim.Adam):
         >>> loss.backward()
 
         >>> # Optimization step with selective updates
-        >>> optimizer.step(visibility=visibility_mask, N=N)
+        >>> optimizer.step(visibility=visibility_mask)
 
     """
 
@@ -45,7 +45,8 @@ def __init__(self, params, eps, betas):
         super().__init__(params=params, eps=eps, betas=betas)
 
     @torch.no_grad()
-    def step(self, visibility, N):
+    def step(self, visibility):
+        N = visibility.numel()
         for group in self.param_groups:
             lr = group["lr"]
             eps = group["eps"]
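
For reference, a minimal usage sketch of the call pattern after this change (not part of the diff). The tensor name `means`, the placeholder loss, and the random visibility mask are illustrative assumptions; SelectiveAdam's update runs through gsplat's CUDA extension, so the tensors are placed on a CUDA device.

# Sketch of the updated SelectiveAdam API; names below are illustrative,
# not taken from the diff, and a CUDA device is assumed.
import torch
from gsplat.optimizers.selective_adam import SelectiveAdam

N = 1_000
means = torch.randn(N, 3, device="cuda", requires_grad=True)
optimizer = SelectiveAdam([means], eps=1e-15, betas=(0.9, 0.999))

loss = (means ** 2).sum()  # placeholder loss for illustration
loss.backward()

# Only Gaussians marked visible receive an Adam update; the Gaussian count
# is now inferred internally via visibility.numel(), so N is no longer passed.
visibility_mask = torch.rand(N, device="cuda") > 0.5
optimizer.step(visibility=visibility_mask)
optimizer.zero_grad(set_to_none=True)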