We present the PyTorch code for *Sharpness-Aware Minimization Revisited: Weighted Sharpness as a Regularization Term* (KDD '23). The code is based on https://github.com/davda54/sam.
The generalization of Deep Neural Networks (DNNs) is known to be closely related to the flatness of minima, which led to the development of Sharpness-Aware Minimization (SAM) for seeking flatter minima and better generalization. We propose a more general method, called WSAM, which incorporates sharpness as a regularization term. WSAM achieves improved, or at least highly competitive, generalization compared to the vanilla optimizer, SAM, and its variants.
WSAM can reach different (flatter) minima by choosing different values of γ.
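Conceptually (this is a paraphrase; see the paper for the exact formulation), the sharpness at a point $w$ is the gap between the perturbed SAM loss and the vanilla loss, and WSAM minimizes the vanilla loss plus this gap weighted by a factor controlled by γ:

$$
L^{\mathrm{WSAM}}(w) \;=\; L(w) \;+\; \frac{\gamma}{1-\gamma}\Bigl(\max_{\lVert \epsilon \rVert \le \rho} L(w+\epsilon) \;-\; L(w)\Bigr),
$$

so that γ = 1/2 recovers SAM and γ = 0 recovers the base optimizer, while larger γ puts more weight on the sharpness term.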
Similar to SAM, WSAM can be used in a two-step manner or with a single closure-based function.
```python
import torch
from atorch.optimizers.wsam import WeightedSAM
from atorch.optimizers.utils import enable_running_stats, disable_running_stats
...

model = YourModel()
base_optimizer = torch.optim.SGD(model.parameters(), lr=0.001)  # initialize the base optimizer
optimizer = WeightedSAM(model, base_optimizer, rho=0.05, gamma=0.9, adaptive=False, decouple=True, max_norm=None)
...

# 1. two-step method
for input, output in data:
    enable_running_stats(model)
    with model.no_sync():
        # first forward-backward pass
        loss = loss_function(output, model(input))  # use this loss for any training statistics
        loss.backward()
    optimizer.first_step(zero_grad=True)

    disable_running_stats(model)
    # second forward-backward pass
    loss_function(output, model(input)).backward()  # make sure to do a full forward pass
    optimizer.second_step(zero_grad=True)
...

# 2. closure-based method
for input, output in data:
    def closure():
        loss = loss_function(output, model(input))
        loss.backward()
        return loss

    loss = loss_function(output, model(input))
    loss.backward()
    optimizer.step(closure)
    optimizer.zero_grad()
...
```
- Regularization mode: We recommend performing a decoupled update of the sharpness term (`decouple=True`), as used in our paper; see the sketch after this list.
- Gradient clipping: To ensure training stability, WSAM performs gradient clipping when `max_norm` is not `None` (see the configuration example after this list).
- Gradient sync: This implementation synchronizes gradients correctly, corresponding to the m-sharpness used in the SAM paper.
- Rho selection: If you try to reproduce the ViT results from this paper, use a larger `rho` when training with fewer GPUs. For more information, see this related link.
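To illustrate what the decoupled update means, here is a conceptual sketch only, by analogy with decoupled weight decay; it is not the atorch implementation, and the names `g_vanilla`, `g_sam` and the exact scaling are illustrative assumptions:

```python
import torch

lr, gamma = 0.001, 0.9
alpha = gamma / (1.0 - gamma)   # weight of the sharpness term (assumption)

w = torch.zeros(3, requires_grad=True)
base_optimizer = torch.optim.SGD([w], lr=lr)

g_vanilla = torch.randn(3)      # stand-in for the gradient of L(w)
g_sam = torch.randn(3)          # stand-in for the gradient at the perturbed point w + eps
g_sharp = g_sam - g_vanilla     # stand-in for the gradient of the sharpness term

# coupled update (decouple=False): the full weighted gradient is fed to the base optimizer,
# e.g. w.grad = g_vanilla + alpha * g_sharp; base_optimizer.step()

# decoupled update (decouple=True): only the vanilla gradient goes through the base
# optimizer, and the weighted sharpness term is applied to the weights directly
w.grad = g_vanilla.clone()
base_optimizer.step()
with torch.no_grad():
    w.add_(g_sharp, alpha=-lr * alpha)
```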
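And a minimal configuration sketch for enabling gradient clipping; the argument names are the ones shown in the constructor above, and the specific values here are only examples:

```python
import torch
from atorch.optimizers.wsam import WeightedSAM

model = YourModel()
base_optimizer = torch.optim.SGD(model.parameters(), lr=0.001)
# max_norm=1.0 turns on gradient clipping (a norm threshold of 1.0 here is an example value);
# decouple=True keeps the recommended decoupled update of the sharpness term.
optimizer = WeightedSAM(model, base_optimizer, rho=0.05, gamma=0.9,
                        adaptive=False, decouple=True, max_norm=1.0)
```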