add sam optimizer
Lupin1998 committed Jun 19, 2024
1 parent b49e19d commit 4575a8b
Showing 10 changed files with 208 additions and 20 deletions.
@@ -0,0 +1,18 @@
_base_ = 'r50_CE_none_bs100.py'

data = dict(
train=dict(data_source=dict(repeat=2)), # repeat 2 times
)

# optimizer
optimizer = dict(
type='SAMAdam',
lr=5e-4, weight_decay=0.0,
rho=0.05, adaptive=False,
paramwise_options={
'(bn|ln|gn)(\d+)?.(weight|bias)': dict(weight_decay=0.),
'norm': dict(weight_decay=0.),
'bias': dict(weight_decay=0.),
})

optimizer_config = dict(update_interval=2) # repeat 2 times
@@ -0,0 +1,18 @@
_base_ = 'r50_CE_none_bs100.py'

data = dict(
train=dict(data_source=dict(repeat=2)), # repeat 2 times
)

# optimizer
optimizer = dict(
type='SAMAdamW',
lr=1e-3, weight_decay=0.05, eps=1e-8, betas=(0.9, 0.999),
rho=0.05, adaptive=False,
paramwise_options={
'(bn|ln|gn)(\d+)?.(weight|bias)': dict(weight_decay=0.),
'norm': dict(weight_decay=0.),
'bias': dict(weight_decay=0.),
})

optimizer_config = dict(update_interval=2) # repeat 2 times
@@ -0,0 +1,12 @@
_base_ = 'r50_CE_none_bs100.py'

data = dict(
train=dict(data_source=dict(repeat=2)), # repeat 2 times
)

# optimizer
optimizer = dict(
type='SAMSGD', lr=0.1, momentum=0.9, weight_decay=0.0001,
rho=0.05, adaptive=False)

optimizer_config = dict(update_interval=2) # repeat 2 times
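
The three configs above switch the `r50_CE_none_bs100.py` baseline to the SAM variants registered in this commit (`SAMAdam`, `SAMAdamW`, `SAMSGD`). As a rough orientation only (not part of the commit), the sketch below shows approximately what such a config amounts to once the mmcv `OPTIMIZERS` registry builds it, assuming the package from this commit is installed; the linear layer is a placeholder standing in for the ResNet-50 classifier.

```python
# Hedged sketch only: what a config such as
#   optimizer = dict(type='SAMSGD', lr=0.1, momentum=0.9, weight_decay=0.0001,
#                    rho=0.05, adaptive=False)
# roughly resolves to, given the SAM / SAMSGD classes added in sam.py below.
import torch
import torch.nn as nn
from openmixup.core.optimizers import SAM  # exported by this commit

model = nn.Linear(2048, 100)  # placeholder for the ResNet-50 classifier head

# SAMSGD wraps torch.optim.SGD with the SAM perturbation parameters:
optimizer = SAM(model.parameters(), torch.optim.SGD,
                rho=0.05, adaptive=False,
                lr=0.1, momentum=0.9, weight_decay=0.0001)
```
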
3 changes: 2 additions & 1 deletion docs/en/changelog.md
@@ -1,6 +1,6 @@
## Changelog

### v0.2.9 (23/12/2023)
### v0.2.9 (23/12/2023 till now)

Bump version to V0.2.9 with new mixup augmentations and various optimizers.

@@ -11,6 +11,7 @@ Bump version to V0.2.9 with new mixup augmentations and various optimizers.
- Support more PyTorch optimizers implemented, including Adam variants (e.g., AdaBelief, AdaFactor) and SGD variants (e.g., SGDP).
- Support evaluation tools for mixup augmentations, including robustness testing (corruption and adversarial attack robustness) and calibration evaluation.
- Provide more config files for self-supervised learning methods on small-scale datasets (CIFAR-100 and STL-10).
- Support [Sharpness-Aware Minimization (SAM)](https://openreview.net/forum?id=6Tm1mposlrM) optimizer variants for small-scale datasets.

### v0.2.8 (25/05/2023)

6 changes: 6 additions & 0 deletions openmixup/core/hooks/optimizer_hook.py
@@ -123,6 +123,12 @@ def after_train_iter(self, runner):
# update
runner.optimizer.step()
runner.optimizer.zero_grad()
else:
if getattr(runner.optimizer, 'base_optimizer', None):
# first forward-backward pass for SAM optimizers
if self.grad_clip is not None:
grad_norm = self.clip_grads(runner.model.parameters())
runner.optimizer.first_step(zero_grad=True)


if (TORCH_VERSION != 'parrots'
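
The hunk above adds the SAM branch to the optimizer hook: on the non-update iterations (e.g. with `update_interval=2` in the configs above), a SAM-style optimizer (detected via its `base_optimizer` attribute) gets a `first_step()` call, which perturbs the weights to `w + e(w)`; on the update iteration, `runner.optimizer.step()` dispatches to `second_step()`, which restores `w` and applies the base optimizer using gradients taken at the perturbed point. A standalone sketch of the same two-pass update, with `model`, `criterion`, and the batch as placeholders:

```python
# Minimal sketch of the two forward-backward passes per SAM update, assuming
# an optimizer that exposes first_step()/second_step() as in sam.py below.
def sam_update(model, criterion, optimizer, images, labels):
    # Pass 1: gradients at the current weights w, then move to w + e(w).
    criterion(model(images), labels).backward()
    optimizer.first_step(zero_grad=True)

    # Pass 2: gradients at the perturbed weights drive the real update of w.
    criterion(model(images), labels).backward()
    optimizer.second_step(zero_grad=True)
```
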
3 changes: 2 additions & 1 deletion openmixup/core/optimizers/__init__.py
@@ -11,11 +11,12 @@
from .lion import Lion
from .madgrad import MADGRAD
from .nvnovograd import NvNovoGrad
from .sam import SAM
from .sgdp import SGDP
from .sophia import SophiaG

__all__ = [
'AdaBelief', 'AdaBound', 'AdaBoundW', 'Adafactor', 'Adahessian', 'AdamP', 'Adan',
'LARS', 'LAMB', 'Lion', 'MADGRAD', 'NvNovoGrad', 'SGDP', 'SophiaG',
'LARS', 'LAMB', 'Lion', 'MADGRAD', 'NvNovoGrad', 'SAM', 'SGDP', 'SophiaG',
'build_optimizer', 'DefaultOptimizerConstructor', 'TransformerFinetuneConstructor'
]
12 changes: 8 additions & 4 deletions openmixup/core/optimizers/lars.py
@@ -38,7 +38,8 @@ def __init__(self,
dampening=0,
weight_decay=0,
eta=0.001,
nesterov=False):
nesterov=False,
none_zero=False):
if lr is not required and lr < 0.0:
raise ValueError("Invalid learning rate: {}".format(lr))
if momentum < 0.0:
@@ -48,6 +49,7 @@
"Invalid weight_decay value: {}".format(weight_decay))
if eta < 0.0:
raise ValueError("Invalid LARS coefficient value: {}".format(eta))
self.none_zero = none_zero

defaults = dict(
lr=lr, momentum=momentum, dampening=dampening,
@@ -96,9 +98,11 @@ def step(self, closure=None):
weight_norm = torch.norm(p).item()
grad_norm = torch.norm(d_p).item()
# Compute local learning rate for this layer
local_lr = eta * weight_norm / \
(grad_norm + weight_decay * weight_norm)

local_lr = 1.
if self.none_zero:
    if weight_norm != 0 and grad_norm != 0:
        local_lr = eta * weight_norm / \
            (grad_norm + weight_decay * weight_norm)
actual_lr = local_lr * lr
d_p = d_p.add(p, alpha=weight_decay).mul(actual_lr)
if momentum != 0:
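
For clarity (not part of the diff), the new `none_zero` flag gates the LARS trust ratio: when the flag is set and both the weight norm and the gradient norm are non-zero, the usual layer-wise rate is computed; otherwise the layer falls back to the global learning rate. Condensed into one helper, with names mirroring `lars.py`:

```python
def lars_local_lr(weight_norm, grad_norm, eta, weight_decay, none_zero):
    # Mirrors the updated step() logic above; local_lr multiplies the global lr.
    local_lr = 1.
    if none_zero and weight_norm != 0 and grad_norm != 0:
        local_lr = eta * weight_norm / (grad_norm + weight_decay * weight_norm)
    return local_lr
```
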
108 changes: 108 additions & 0 deletions openmixup/core/optimizers/sam.py
@@ -0,0 +1,108 @@
import torch
from mmcv.runner.optimizer.builder import OPTIMIZERS
from torch.optim.optimizer import Optimizer, required


class SAM(Optimizer):
r"""Sharpness-Aware Minimization (SAM) optimizer.
Implementation of `Sharpness-Aware Minimization for Efficiently Improving Generalization
(ICLR'2021) <https://openreview.net/forum?id=6Tm1mposlrM>`_.
https://github.com/davda54/sam
https://github.com/google-research/sam
"""
def __init__(self, params, base_optimizer, rho=0.05, adaptive=False, **kwargs):
assert rho >= 0.0, f"Invalid rho, should be non-negative: {rho}"

defaults = dict(rho=rho, adaptive=adaptive, **kwargs)
super(SAM, self).__init__(params, defaults)

self.base_optimizer = base_optimizer(self.param_groups, **kwargs)
self.param_groups = self.base_optimizer.param_groups
self.defaults.update(self.base_optimizer.defaults)

@torch.no_grad()
def first_step(self, zero_grad=False):
grad_norm = self._grad_norm()
for group in self.param_groups:
scale = group["rho"] / (grad_norm + 1e-12)

for p in group["params"]:
if p.grad is None: continue
self.state[p]["old_p"] = p.data.clone()
e_w = (torch.pow(p, 2) if group["adaptive"] else 1.0) * p.grad * scale.to(p)
p.add_(e_w) # climb to the local maximum "w + e(w)"

if zero_grad: self.zero_grad()

@torch.no_grad()
def second_step(self, zero_grad=False):
for group in self.param_groups:
for p in group["params"]:
if p.grad is None: continue
p.data = self.state[p]["old_p"] # get back to "w" from "w + e(w)"

self.base_optimizer.step() # do the actual "sharpness-aware" update

if zero_grad: self.zero_grad()

@torch.no_grad()
def step(self, closure=None):
# assert closure is not None, \
# "Sharpness Aware Minimization requires closure, but it was not provided"
if closure is not None:
closure = torch.enable_grad()(closure) # the closure should do a full forward-backward pass
self.first_step(zero_grad=True)
closure()
self.second_step(zero_grad=True)
else:
self.second_step(zero_grad=True)

def _grad_norm(self):
# put everything on the same device, in case of model parallelism
shared_device = self.param_groups[0]["params"][0].device
norm = torch.norm(
torch.stack([
((torch.abs(p) if group["adaptive"] else 1.0
) * p.grad).norm(p=2).to(shared_device)
for group in self.param_groups for p in group["params"]
if p.grad is not None
]),
p=2)
return norm

def load_state_dict(self, state_dict):
super().load_state_dict(state_dict)
self.base_optimizer.param_groups = self.param_groups


@OPTIMIZERS.register_module()
class SAMAdam(SAM):
def __init__(self, params, lr=required,
betas=(0.9, 0.999), eps=1e-8, weight_decay=0, amsgrad=False,
rho=0.05, adaptive=False, **kwargs):
defaults_opt = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay, amsgrad=amsgrad)
super().__init__(
params=params, base_optimizer=torch.optim.Adam, rho=rho, adaptive=adaptive, **defaults_opt)


@OPTIMIZERS.register_module()
class SAMAdamW(SAM):
def __init__(self, params, lr=required,
betas=(0.9, 0.999), eps=1e-8, weight_decay=1e-2, amsgrad=False,
rho=0.05, adaptive=False, **kwargs):
defaults_opt = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay, amsgrad=amsgrad)
super().__init__(
params=params, base_optimizer=torch.optim.AdamW, rho=rho, adaptive=adaptive, **defaults_opt)


@OPTIMIZERS.register_module()
class SAMSGD(SAM):
def __init__(self, params, lr=required,
momentum=0, dampening=0, weight_decay=0, nesterov=False,
rho=0.05, adaptive=False):
defaults_opt = dict(lr=lr, momentum=momentum, dampening=dampening,
weight_decay=weight_decay, nesterov=nesterov)
super().__init__(
params=params, base_optimizer=torch.optim.SGD, rho=rho, adaptive=adaptive, **defaults_opt)
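
A hedged usage sketch for the new `SAM` class outside the runner hooks. Gradients must already be populated before `step(closure)` is called, since `first_step()` builds the perturbation `e(w) = rho * g / ||g||` (elementwise scaled by `w^2` in the adaptive/ASAM case) from the current `p.grad`; the closure then supplies the second forward-backward pass at `w + e(w)`. The model, criterion, and batch below are placeholders, and openmixup (or a pasted copy of the `SAM` class above) is assumed to be importable.

```python
import torch
import torch.nn as nn
from openmixup.core.optimizers import SAM  # added by this commit

model = nn.Linear(10, 2)                    # placeholder network
criterion = nn.CrossEntropyLoss()
x, y = torch.randn(8, 10), torch.randint(0, 2, (8,))

optimizer = SAM(model.parameters(), torch.optim.SGD,
                rho=0.05, lr=0.1, momentum=0.9)

def closure():
    # Second forward-backward pass, evaluated at the perturbed weights.
    loss = criterion(model(x), y)
    loss.backward()
    return loss

# First forward-backward pass at the current weights.
criterion(model(x), y).backward()
optimizer.step(closure)  # first_step() -> closure() -> second_step()
```
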
17 changes: 12 additions & 5 deletions openmixup/datasets/data_sources/cifar.py
@@ -21,12 +21,13 @@ class Cifar(metaclass=ABCMeta):

CLASSES = None

def __init__(self, root, split, return_label=True, num_labeled=None):
def __init__(self, root, split, return_label=True, num_labeled=None, repeat=1):
assert split in ['train', 'test']
self.root = root
self.split = split
self.return_label = return_label
self.num_labeled = num_labeled
self.repeat = int(repeat)
self.cifar = None
self.set_cifar()
self.labels = self.cifar.targets
@@ -60,13 +61,16 @@ class CIFAR10(Cifar):
'horse', 'ship', 'truck'
]

def __init__(self, root, split, return_label=True, num_labeled=None):
super().__init__(root, split, return_label, num_labeled)
def __init__(self, root, split, return_label=True, num_labeled=None, repeat=1):
super().__init__(root, split, return_label, num_labeled, repeat)

def set_cifar(self):
try:
self.cifar = torchvision.datasets.CIFAR10(
root=self.root, train=self.split == 'train', download=False)
if self.repeat > 1:
self.cifar.data = np.concatenate([self.cifar.data] * self.repeat)
self.cifar.targets = np.concatenate([self.cifar.targets] * self.repeat)
except:
raise Exception("Please download CIFAR10 manually, \
in case of downloading the dataset parallelly \
@@ -132,13 +136,16 @@ class CIFAR100(Cifar):
'lawn-mower', 'rocket', 'streetcar', 'tank', 'tractor'
]

def __init__(self, root, split, return_label=True, num_labeled=None):
super().__init__(root, split, return_label, num_labeled)
def __init__(self, root, split, return_label=True, num_labeled=None, repeat=1):
super().__init__(root, split, return_label, num_labeled, repeat)

def set_cifar(self):
try:
self.cifar = torchvision.datasets.CIFAR100(
root=self.root, train=self.split == 'train', download=False)
if self.repeat > 1:
self.cifar.data = np.concatenate([self.cifar.data] * self.repeat)
self.cifar.targets = np.concatenate([self.cifar.targets] * self.repeat)
except:
raise Exception("Please download CIFAR10 manually, \
in case of downloading the dataset parallelly \
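
The `cifar.py` changes above (and the analogous `mnist.py` changes below) add a `repeat` option that simply tiles the in-memory arrays so every sample appears `repeat` times per epoch, matching the `repeat=2` / `update_interval=2` pairing in the SAM configs. A toy illustration with placeholder arrays:

```python
import numpy as np

data = np.arange(5)        # stands in for self.cifar.data
targets = list(range(5))   # stands in for self.cifar.targets
repeat = 2

if repeat > 1:
    data = np.concatenate([data] * repeat)
    targets = np.concatenate([targets] * repeat)

print(len(data), len(targets))  # -> 10 10
```
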
31 changes: 22 additions & 9 deletions openmixup/datasets/data_sources/mnist.py
@@ -17,11 +17,12 @@ class Mnist_base(metaclass=ABCMeta):

CLASSES = None

def __init__(self, root, split, return_label=True):
def __init__(self, root, split, return_label=True, repeat=1):
assert split in ['train', 'test']
self.root = root
self.split = split
self.return_label = return_label
self.repeat = int(repeat)
self.mnist = None
self.set_mnist()
self.labels = self.mnist.targets
@@ -50,13 +51,16 @@ class USPS(Mnist_base):
CLASSES = ['0 - zero', '1 - one', '2 - two', '3 - three', '4 - four',
'5 - five', '6 - six', '7 - seven', '8 - eight', '9 - nine']

def __init__(self, root, split, return_label=True):
super().__init__(root, split, return_label)
def __init__(self, root, split, return_label=True, repeat=1):
super().__init__(root, split, return_label, repeat)

def set_mnist(self):
try:
self.mnist = torchvision.datasets.USPS(
root=self.root, train=self.split == 'train', download=False)
if self.repeat > 1:
self.mnist.data = np.concatenate([self.mnist.data] * self.repeat)
self.mnist.targets = np.concatenate([self.mnist.targets] * self.repeat)
except:
raise Exception("Please download USPS binary manually, \
in case of downloading the dataset parallelly \
@@ -69,13 +73,16 @@ class MNIST(Mnist_base):
CLASSES = ['0 - zero', '1 - one', '2 - two', '3 - three', '4 - four',
'5 - five', '6 - six', '7 - seven', '8 - eight', '9 - nine']

def __init__(self, root, split, return_label=True):
super().__init__(root, split, return_label)
def __init__(self, root, split, return_label=True, repeat=1):
super().__init__(root, split, return_label, repeat)

def set_mnist(self):
try:
self.mnist = torchvision.datasets.MNIST(
root=self.root, train=self.split == 'train', download=False)
if self.repeat > 1:
self.mnist.data = np.concatenate([self.mnist.data] * self.repeat)
self.mnist.targets = np.concatenate([self.mnist.targets] * self.repeat)
except:
raise Exception("Please download MNIST manually, \
in case of downloading the dataset parallelly \
@@ -88,13 +95,16 @@ class FMNIST(Mnist_base):
CLASSES = ['T-shirt/top', 'Trouser', 'Pullover', 'Dress', 'Coat', 'Sandal',
'Shirt', 'Sneaker', 'Bag', 'Ankle boot']

def __init__(self, root, split, return_label=True):
super().__init__(root, split, return_label)
def __init__(self, root, split, return_label=True, repeat=1):
super().__init__(root, split, return_label, repeat)

def set_mnist(self):
try:
self.mnist = torchvision.datasets.FashionMNIST(
root=self.root, train=self.split == 'train', download=False)
if self.repeat > 1:
self.mnist.data = np.concatenate([self.mnist.data] * self.repeat)
self.mnist.targets = np.concatenate([self.mnist.targets] * self.repeat)
except:
raise Exception("Please download FashionMNIST manually, \
in case of downloading the dataset parallelly \
@@ -106,13 +116,16 @@ class KMNIST(Mnist_base):

CLASSES = ['o', 'ki', 'su', 'tsu', 'na', 'ha', 'ma', 'ya', 're', 'wo']

def __init__(self, root, split, return_label=True):
super().__init__(root, split, return_label)
def __init__(self, root, split, return_label=True, repeat=1):
super().__init__(root, split, return_label, repeat)

def set_mnist(self):
try:
self.mnist = torchvision.datasets.KMNIST(
root=self.root, train=self.split == 'train', download=False)
if self.repeat > 1:
self.mnist.data = np.concatenate([self.mnist.data] * self.repeat)
self.mnist.targets = np.concatenate([self.mnist.targets] * self.repeat)
except:
raise Exception("Please download KMNIST manually, \
in case of downloading the dataset parallelly \