Source code for maldideepkit.utils.sam
"""Sharpness-Aware Minimization (SAM) optimizer wrapper.

SAM pushes weights toward flatter regions of the loss landscape, at
the cost of ~2x the forward / backward compute per step.

Usage
-----
.. code-block:: python

    optimizer = SAMOptimizer(
        model.parameters(), base_optimizer=torch.optim.AdamW,
        rho=0.05, lr=1e-3, weight_decay=0.05,
    )

    loss = criterion(model(x), y)
    loss.backward()
    optimizer.first_step(zero_grad=True)

    loss = criterion(model(x), y)
    loss.backward()
    optimizer.second_step(zero_grad=True)
"""
from __future__ import annotations
from typing import Any
import torch
[docs]
class SAMOptimizer(torch.optim.Optimizer):
    """Wrap a base optimizer in the SAM two-step update.

    A training step consists of two forward/backward passes:
    :meth:`first_step` perturbs the weights along the current gradient
    direction, and :meth:`second_step` restores the original weights and
    lets the base optimizer update them using the gradients computed at
    the perturbed point.

    Parameters
    ----------
    params : iterable
        Parameters or param-group dicts (as for any torch optimizer).
    base_optimizer : type
        The base optimizer **class** (e.g. :class:`torch.optim.AdamW`).
        Instantiated internally against the same param groups.
    rho : float, default=0.05
        Size of the ascent step in parameter space. Paper default is
        ``0.05``. Typical range: ``[0.01, 0.2]``.
    **base_kwargs
        Forwarded to the base optimizer (e.g. ``lr``, ``weight_decay``).

    Raises
    ------
    ValueError
        If ``rho`` is not strictly positive.
    """

    def __init__(
        self,
        params: Any,
        base_optimizer: type[torch.optim.Optimizer],
        rho: float = 0.05,
        **base_kwargs: Any,
    ) -> None:
        if rho <= 0:
            raise ValueError(f"rho must be > 0; got {rho!r}.")
        defaults = {"rho": float(rho), **base_kwargs}
        super().__init__(params, defaults)
        # Build the base optimizer on the very same group dicts, then alias
        # its list so both optimizers always see identical param groups.
        self.base_optimizer = base_optimizer(self.param_groups, **base_kwargs)
        self.param_groups = self.base_optimizer.param_groups
        for g in self.param_groups:
            g.setdefault("rho", float(rho))

    @torch.no_grad()
    def first_step(self, zero_grad: bool = False) -> None:
        """Ascend to ``w + e`` using the current gradients.

        ``e`` is ``rho * grad / (||grad|| + 1e-12)``, where the norm is the
        global L2 norm over every parameter that has a gradient.
        Parameters without a gradient are left untouched.

        Parameters
        ----------
        zero_grad : bool, default=False
            If true, call ``zero_grad()`` after perturbing the weights.
        """
        grad_norm = self._grad_norm()
        eps = 1e-12
        for group in self.param_groups:
            scale = group["rho"] / (grad_norm + eps)
            for p in group["params"]:
                if p.grad is None:
                    continue
                # Snapshot the weights instead of remembering the
                # perturbation: (w + e) - e is not bit-exact in floating
                # point, so undoing by subtraction drifts the weights a
                # little on every step.
                self.state[p]["old_p"] = p.detach().clone()
                # ``scale`` lives on the device used to reduce the global
                # norm; move it next to each gradient so parameters on
                # other devices can be perturbed as well.
                p.add_(p.grad * scale.to(p.grad.device))
        if zero_grad:
            self.zero_grad()

    @torch.no_grad()
    def second_step(self, zero_grad: bool = False) -> None:
        """Undo the ascent and apply the base optimizer step from ``w``.

        Every parameter perturbed by :meth:`first_step` is restored to its
        saved value; the base optimizer then steps using whatever
        gradients are currently stored on the parameters.

        Parameters
        ----------
        zero_grad : bool, default=False
            If true, call ``zero_grad()`` after the base optimizer step.
        """
        for group in self.param_groups:
            for p in group["params"]:
                # ``get`` avoids creating state entries for parameters
                # that were never perturbed.
                saved = self.state.get(p, {}).pop("old_p", None)
                if saved is not None:
                    p.copy_(saved)
        self.base_optimizer.step()
        # Mirror the base optimizer's step bookkeeping onto the wrapper.
        # NOTE(review): presumably so LR schedulers attached to the
        # wrapper see that a step happened - confirm.
        self._step_count = getattr(self.base_optimizer, "_step_count", 0)
        self._opt_called = True
        if zero_grad:
            self.zero_grad()

    def step(self, closure: Any = None) -> Any:
        """Unsupported. Use ``first_step`` / ``second_step`` instead.

        Raises
        ------
        RuntimeError
            Always.
        """
        raise RuntimeError(
            "SAMOptimizer requires an explicit two-pass training loop. "
            "Call first_step() after the first backward, then recompute "
            "the loss and backward, then call second_step()."
        )

    @torch.no_grad()
    def _grad_norm(self) -> torch.Tensor:
        """Return the global L2 norm of all available gradients.

        Per-parameter norms are moved to the device of the first gradient
        found before being combined. When no parameter has a gradient a
        CPU scalar ``0.0`` tensor is returned.
        """
        grads = [
            p.grad
            for group in self.param_groups
            for p in group["params"]
            if p.grad is not None
        ]
        if not grads:
            return torch.tensor(0.0)
        ref_device = grads[0].device
        norms = [g.norm(p=2).to(ref_device) for g in grads]
        return torch.norm(torch.stack(norms), p=2)