
Commit

add docstring
rahul-goel committed Oct 1, 2024
1 parent 35412ec commit 464d064
Showing 1 changed file with 35 additions and 0 deletions.
gsplat/optimizers/selective_adam.py (+35, −0)
@@ -4,6 +4,41 @@


class SelectiveAdam(torch.optim.Adam):
"""
A custom optimizer that extends the standard Adam optimizer by
incorporating selective updates.
This class is useful for situations where only a subset of parameters
should be updated at each step, such as in sparse models or in cases where
parameter visibility is controlled by an external mask.
Additionally, the operations are fused into a single kernel. This optimizer
leverages the `selective_adam_update` function from a CUDA backend for
optimized sparse updates.
Args:
params (iterable): Iterable of parameters to optimize or dicts defining parameter groups.
eps (float): Term added to the denominator to improve numerical stability (default: 1e-8).
betas (Tuple[float, float]): Coefficients used for computing running averages of gradient and its square (default: (0.9, 0.999)).
Examples:
>>> N = 100
>>> param = torch.randn(N, requires_grad=True)
>>> optimizer = SelectiveAdam([param], eps=1e-8, betas=(0.9, 0.999))
>>> visibility_mask = torch.cat([torch.ones(50), torch.zeros(50)]) # Visible first half, hidden second half
>>> # Forward pass
>>> loss = torch.sum(param ** 2)
>>> # Backward pass
>>> loss.backward()
>>> # Optimization step with selective updates
>>> optimizer.step(visibility=visibility_mask, N=N)
"""

def __init__(self, params, eps, betas):
super().__init__(params=params, eps=eps, betas=betas)
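
For context, here is a minimal usage sketch of the new optimizer in a Gaussian-splatting-style training loop. This is an illustration only: the import path is inferred from the file changed in this commit, the random mask stands in for a real per-Gaussian visibility mask produced by the rasterizer, the dummy loss is not a rendering loss, and the step(visibility=..., N=N) call simply follows the docstring example above.

import torch

# Import path inferred from the file changed in this commit.
from gsplat.optimizers.selective_adam import SelectiveAdam

N = 1000  # number of Gaussians (illustrative)
# The optimizer relies on a fused CUDA kernel, so parameters live on the GPU.
means = torch.randn(N, 3, requires_grad=True, device="cuda")
optimizer = SelectiveAdam([means], eps=1e-8, betas=(0.9, 0.999))

for _ in range(100):
    # Stand-in for a real visibility mask; in practice this would come from the
    # rasterizer and flag which of the N Gaussians were touched for this view.
    visibility = (torch.rand(N, device="cuda") > 0.5).float()

    optimizer.zero_grad()
    loss = (means ** 2).sum()  # dummy loss for illustration only
    loss.backward()

    # Only the visible Gaussians are updated; the rest are skipped by the
    # fused `selective_adam_update` kernel.
    optimizer.step(visibility=visibility, N=N)

The float mask mirrors the torch.ones/torch.zeros construction in the docstring example; if the CUDA backend expects a boolean tensor instead, drop the .float() call.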

