Commit

add sophiag optimizer (#108)
ad12 authored Aug 3, 2023
1 parent 450c1af commit 74447ee
Showing 2 changed files with 237 additions and 1 deletion.
meddlr/solver/build.py (4 additions & 1 deletion)
@@ -6,7 +6,7 @@
from meddlr.config import CfgNode

from .lr_scheduler import NoOpLR, WarmupCosineLR, WarmupMultiStepLR
-from .optimizer import GradAccumOptimizer
+from .optimizer import GradAccumOptimizer, SophiaG


def build_optimizer(cfg: CfgNode, model: torch.nn.Module) -> torch.optim.Optimizer:
@@ -36,6 +36,9 @@ def _build_opt(params, cfg):
optimizer = torch.optim.SGD(params, cfg.SOLVER.BASE_LR, momentum=cfg.SOLVER.MOMENTUM)
elif optim == "Adam":
optimizer = torch.optim.Adam(params, cfg.SOLVER.BASE_LR)
elif optim == "SophiaG":
# weight decay handled by build_optimizer
optimizer = SophiaG(params, cfg.SOLVER.BASE_LR, weight_decay=0.0)
else:
raise ValueError(f"Optimizer {optim} not supported")
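
For context, a minimal usage sketch (not part of this commit) of driving the new branch from a config. The `cfg.SOLVER.OPTIMIZER` key and the `get_cfg()` helper are assumptions, not shown in this hunk:

```python
# Hypothetical sketch; the config key and get_cfg() helper are assumed.
import torch
from meddlr.config import get_cfg            # assumed default-config helper
from meddlr.solver.build import build_optimizer

cfg = get_cfg()
cfg.SOLVER.OPTIMIZER = "SophiaG"              # assumed key; routes to the SophiaG branch
cfg.SOLVER.BASE_LR = 1e-4

model = torch.nn.Linear(8, 1)                 # any torch.nn.Module
optimizer = build_optimizer(cfg, model)       # weight decay is applied by build_optimizer
```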

meddlr/solver/optimizer.py (233 additions & 0 deletions)
@@ -1,5 +1,8 @@
import logging
from typing import List

import torch
from torch import Tensor
from torch.optim.optimizer import Optimizer

__all__ = ["GradAccumOptimizer"]
@@ -59,3 +62,233 @@ def _step(self, closure=None):
def __getattr__(self, item):
if hasattr(self.optimizer, item):
return getattr(self.optimizer, item)


class SophiaG(Optimizer):
    """Sophia optimizer with the Gauss-Newton-Bartlett (GNB) Hessian estimator.

    Adapted from https://github.com/Liuhong99/Sophia/blob/main/sophia.py
    (Liu et al., "Sophia: A Scalable Stochastic Second-order Optimizer for
    Language Model Pre-training", 2023).
    """

def __init__(
self,
params,
lr=1e-4,
betas=(0.965, 0.99),
rho=0.04,
weight_decay=1e-1,
*,
maximize: bool = False,
capturable: bool = False,
):
if not 0.0 <= lr:
raise ValueError("Invalid learning rate: {}".format(lr))
if not 0.0 <= betas[0] < 1.0:
raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0]))
if not 0.0 <= betas[1] < 1.0:
raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1]))
if not 0.0 <= rho:
            raise ValueError("Invalid rho parameter: {}".format(rho))
if not 0.0 <= weight_decay:
raise ValueError("Invalid weight_decay value: {}".format(weight_decay))
defaults = dict(
lr=lr,
betas=betas,
rho=rho,
weight_decay=weight_decay,
maximize=maximize,
capturable=capturable,
)
super(SophiaG, self).__init__(params, defaults)

def __setstate__(self, state):
super().__setstate__(state)
for group in self.param_groups:
group.setdefault("maximize", False)
group.setdefault("capturable", False)
state_values = list(self.state.values())
step_is_tensor = (len(state_values) != 0) and torch.is_tensor(state_values[0]["step"])
if not step_is_tensor:
for s in state_values:
s["step"] = torch.tensor(float(s["step"]))

@torch.no_grad()
def update_hessian(self):
for group in self.param_groups:
beta1, beta2 = group["betas"]
for p in group["params"]:
if p.grad is None:
continue
state = self.state[p]

if len(state) == 0:
state["step"] = (
torch.zeros((1,), dtype=torch.float, device=p.device)
if self.defaults["capturable"]
else torch.tensor(0.0)
)
state["exp_avg"] = torch.zeros_like(p, memory_format=torch.preserve_format)
state["hessian"] = torch.zeros_like(p, memory_format=torch.preserve_format)

if "hessian" not in state.keys():
state["hessian"] = torch.zeros_like(p, memory_format=torch.preserve_format)

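                # Gauss-Newton-Bartlett style diagonal Hessian EMA:
                # h <- beta2 * h + (1 - beta2) * g * g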
state["hessian"].mul_(beta2).addcmul_(p.grad, p.grad, value=1 - beta2)

@torch.no_grad()
def step(self, closure=None, bs=5120):
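        # `bs` scales the Hessian term in the clipping denominator
        # (see _single_tensor_sophiag); callers typically pass their
        # effective batch size. The upstream default is 5120.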
loss = None
if closure is not None:
with torch.enable_grad():
loss = closure()

for group in self.param_groups:
params_with_grad = []
grads = []
exp_avgs = []
state_steps = []
hessian = []
beta1, beta2 = group["betas"]

for p in group["params"]:
if p.grad is None:
continue
params_with_grad.append(p)

if p.grad.is_sparse:
                    raise RuntimeError("SophiaG does not support sparse gradients")
grads.append(p.grad)
state = self.state[p]
# State initialization
if len(state) == 0:
state["step"] = (
torch.zeros((1,), dtype=torch.float, device=p.device)
if self.defaults["capturable"]
else torch.tensor(0.0)
)
state["exp_avg"] = torch.zeros_like(p, memory_format=torch.preserve_format)
state["hessian"] = torch.zeros_like(p, memory_format=torch.preserve_format)

if "hessian" not in state.keys():
state["hessian"] = torch.zeros_like(p, memory_format=torch.preserve_format)

exp_avgs.append(state["exp_avg"])
state_steps.append(state["step"])
hessian.append(state["hessian"])

if self.defaults["capturable"]:
bs = torch.ones((1,), dtype=torch.float, device=p.device) * bs

sophiag(
params_with_grad,
grads,
exp_avgs,
hessian,
state_steps,
bs=bs,
beta1=beta1,
beta2=beta2,
rho=group["rho"],
lr=group["lr"],
weight_decay=group["weight_decay"],
maximize=group["maximize"],
capturable=group["capturable"],
)

return loss
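
A minimal training-loop sketch (illustrative, not from the commit): `step()` runs every iteration, while `update_hessian()` is called every k iterations to refresh the diagonal Hessian EMA. The upstream Sophia-G recipe refreshes from a loss whose labels are sampled from the model's own predictions; a plain regression loss is used here only to keep the example self-contained.

```python
import torch
from meddlr.solver.optimizer import SophiaG

model = torch.nn.Linear(4, 1)
opt = SophiaG(model.parameters(), lr=1e-4, betas=(0.965, 0.99), rho=0.04)

k = 10  # Hessian refresh interval (illustrative choice)
for it in range(100):
    x, y = torch.randn(8, 4), torch.randn(8, 1)
    loss = torch.nn.functional.mse_loss(model(x), y)
    loss.backward()
    opt.step(bs=8)                     # bs feeds the clipping denominator
    opt.zero_grad(set_to_none=True)

    if it % k == k - 1:
        # Fresh backward pass, then refresh the diagonal Hessian EMA.
        hx, hy = torch.randn(8, 4), torch.randn(8, 1)
        torch.nn.functional.mse_loss(model(hx), hy).backward()
        opt.update_hessian()
        opt.zero_grad(set_to_none=True)
```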


def sophiag(
params: List[Tensor],
grads: List[Tensor],
exp_avgs: List[Tensor],
hessian: List[Tensor],
state_steps: List[Tensor],
capturable: bool = False,
*,
bs: int,
beta1: float,
beta2: float,
rho: float,
lr: float,
weight_decay: float,
maximize: bool,
):

if not all(isinstance(t, torch.Tensor) for t in state_steps):
raise RuntimeError(
"API has changed, `state_steps` argument must contain a list of singleton tensors"
)

func = _single_tensor_sophiag

func(
params,
grads,
exp_avgs,
hessian,
state_steps,
bs=bs,
beta1=beta1,
beta2=beta2,
rho=rho,
lr=lr,
weight_decay=weight_decay,
maximize=maximize,
capturable=capturable,
)


def _single_tensor_sophiag(
params: List[Tensor],
grads: List[Tensor],
exp_avgs: List[Tensor],
hessian: List[Tensor],
state_steps: List[Tensor],
*,
bs: int,
beta1: float,
beta2: float,
rho: float,
lr: float,
weight_decay: float,
maximize: bool,
capturable: bool,
):

for i, param in enumerate(params):
grad = grads[i] if not maximize else -grads[i]
exp_avg = exp_avgs[i]
hess = hessian[i]
step_t = state_steps[i]

if capturable:
assert param.is_cuda and step_t.is_cuda and bs.is_cuda

if torch.is_complex(param):
grad = torch.view_as_real(grad)
exp_avg = torch.view_as_real(exp_avg)
hess = torch.view_as_real(hess)
param = torch.view_as_real(param)

# update step
step_t += 1

# Perform stepweight decay
param.mul_(1 - lr * weight_decay)

# Decay the first and second moment running average coefficient
exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)

if capturable:
step = step_t
step_size = lr
step_size_neg = step_size.neg()

ratio = (exp_avg.abs() / (rho * bs * hess + 1e-15)).clamp(None, 1)
param.addcmul_(exp_avg.sign(), ratio, value=step_size_neg)
else:
step = step_t.item() # noqa: F841
step_size_neg = -lr

ratio = (exp_avg.abs() / (rho * bs * hess + 1e-15)).clamp(None, 1)
param.addcmul_(exp_avg.sign(), ratio, value=step_size_neg)
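
For reference, the per-parameter update implemented above can be summarized as follows (a paraphrase of the code, not text from the commit), with m = `exp_avg`, h = `hessian`, g the gradient, $\eta$ = `lr`, $\lambda$ = `weight_decay`, and b = `bs`:

$$
\theta \leftarrow \theta\,(1 - \eta\lambda), \qquad
m \leftarrow \beta_1 m + (1 - \beta_1)\,g, \qquad
\theta \leftarrow \theta - \eta\,\operatorname{sign}(m)\,
\min\!\left(\frac{|m|}{\rho\, b\, h + 10^{-15}},\; 1\right)
$$

That is, a sign-of-momentum step whose element-wise magnitude is clipped by the curvature estimate that `update_hessian()` maintains.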
