Add LARS optimiser.
PiperOrigin-RevId: 378143672
mtthss authored and OptaxDev committed Jun 8, 2021
1 parent 4f23bbd commit 2781a6b
Showing 4 changed files with 76 additions and 6 deletions.
2 changes: 2 additions & 0 deletions optax/__init__.py
@@ -21,6 +21,7 @@
from optax._src.alias import dpsgd
from optax._src.alias import fromage
from optax._src.alias import lamb
from optax._src.alias import lars
from optax._src.alias import noisy_sgd
from optax._src.alias import radam
from optax._src.alias import rmsprop
@@ -192,6 +193,7 @@
"InjectHyperparamsState",
"join_schedules",
"lamb",
"lars",
"log_cosh",
"lookahead",
"LookaheadParams",
59 changes: 57 additions & 2 deletions optax/_src/alias.py
@@ -22,9 +22,11 @@
from optax._src import combine
from optax._src import privacy
from optax._src import transform
from optax._src import wrappers


ScalarOrSchedule = Union[float, base.Schedule]
MaskOrFn = Optional[Union[Any, Callable[[base.Params], Any]]]


def _scale_by_learning_rate(learning_rate: ScalarOrSchedule):
@@ -214,13 +216,62 @@ def fromage(
)


def lars(
    learning_rate: ScalarOrSchedule,
    weight_decay: float = 0.,
    weight_decay_mask: MaskOrFn = True,
    trust_coefficient: float = 0.001,
    eps: float = 0.,
    trust_ratio_mask: MaskOrFn = True,
    momentum: float = 0.9,
    nesterov: bool = False,
) -> base.GradientTransformation:
  """The LARS optimiser.

  LARS is a layer-wise adaptive optimiser introduced to help scale SGD to
  larger batch sizes. LARS later inspired the LAMB optimiser.

  References:
    You et al, 2017: https://arxiv.org/abs/1708.03888

  Args:
    learning_rate: this is a fixed global scaling factor.
    weight_decay (default `0.`): strength of the weight decay regularization.
    weight_decay_mask: a tree with same structure as (or a prefix of) the params
      PyTree, or a Callable that returns such a pytree given the params/updates.
      The leaves should be booleans, `True` for leaves/subtrees you want to
      apply the transformation to, and `False` for those you want to skip.
    trust_coefficient: a multiplier for the trust ratio.
    eps: optional additive constant in the trust ratio denominator.
    trust_ratio_mask: a tree with same structure as (or a prefix of) the params
      PyTree, or a Callable that returns such a pytree given the params/updates.
      The leaves should be booleans, `True` for leaves/subtrees you want to
      apply the transformation to, and `False` for those you want to skip.
    momentum: the decay rate for momentum.
    nesterov: whether to use Nesterov momentum.

  Returns:
    the corresponding `GradientTransformation`.
  """
  return combine.chain(
      transform.add_decayed_weights(weight_decay, mask=weight_decay_mask),
      wrappers.masked(
          inner=transform.scale_by_trust_ratio(
              trust_coefficient=trust_coefficient, eps=eps),
          mask=trust_ratio_mask),
      _scale_by_learning_rate(learning_rate),
      transform.trace(decay=momentum, nesterov=nesterov),
  )
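As a quick illustration (not part of the commit), the following sketch runs one optimisation step with the new `optax.lars` alias; the toy parameter tree and loss function are assumptions made for the example.

import jax
import jax.numpy as jnp
import optax

# Toy parameters and loss, purely illustrative.
params = {'w': jnp.ones((3,)), 'b': jnp.zeros((3,))}

def loss_fn(p):
  return jnp.sum(p['w'] ** 2) + jnp.sum(p['b'] ** 2)

# The alias chains weight decay, masked trust-ratio scaling,
# the learning rate, and (optionally Nesterov) momentum.
optimiser = optax.lars(learning_rate=0.1, weight_decay=1e-4, momentum=0.9)
opt_state = optimiser.init(params)

grads = jax.grad(loss_fn)(params)
# `params` is passed because weight decay and the trust ratio depend on it.
updates, opt_state = optimiser.update(grads, opt_state, params)
params = optax.apply_updates(params, updates)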


def lamb(
    learning_rate: ScalarOrSchedule,
    b1: float = 0.9,
    b2: float = 0.999,
    eps: float = 1e-6,
    eps_root: float = 0.0,
    weight_decay: float = 0.
    weight_decay: float = 0.,
    mask: Optional[Union[Any, Callable[[base.Params], Any]]] = None,
) -> base.GradientTransformation:
  """The LAMB optimiser.
@@ -243,13 +294,17 @@ def lamb(
      the square root (as in RMSProp), to avoid dividing by zero when rescaling.
      This is needed for instance when computing (meta-)gradients through Adam.
    weight_decay (default `0.`): strength of the weight decay regularization.
    mask: a tree with same structure as (or a prefix of) the params PyTree,
      or a Callable that returns such a pytree given the params/updates.
      The leaves should be booleans, `True` for leaves/subtrees you want to
      apply the transformation to, and `False` for those you want to skip.

  Returns:
    the corresponding `GradientTransformation`.
  """
  return combine.chain(
      transform.scale_by_adam(b1=b1, b2=b2, eps=eps, eps_root=eps_root),
      transform.add_decayed_weights(weight_decay),
      transform.add_decayed_weights(weight_decay=weight_decay, mask=mask),
      transform.scale_by_trust_ratio(),
      _scale_by_learning_rate(learning_rate),
  )
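The new `mask` argument is simply forwarded to `add_decayed_weights`, so weight decay can be restricted to part of the parameter tree. A minimal sketch (not from the commit; the parameter names are illustrative assumptions):

import jax.numpy as jnp
import optax

params = {'w': jnp.ones((4, 4)), 'b': jnp.zeros((4,))}

# Boolean tree mirroring `params`: decay the kernel, skip the bias.
decay_mask = {'w': True, 'b': False}

optimiser = optax.lamb(learning_rate=1e-3, weight_decay=1e-2, mask=decay_mask)
opt_state = optimiser.init(params)

A callable that takes the params and returns such a boolean tree works as well.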
13 changes: 11 additions & 2 deletions optax/_src/equivalence_test.py
@@ -94,8 +94,10 @@ class FlaxOptimizersEquivalenceTest(chex.TestCase):

  def setUp(self):
    super().setUp()
    self.init_params = (jnp.array([1., 2.]), jnp.array([3., 4.]))
    self.per_step_updates = (jnp.array([500., 5.]), jnp.array([300., 3.]))
    self.init_params = (
        jnp.array([1., 0.1, 1., 2.]), jnp.array([3., 4.]))
    self.per_step_updates = (
        jnp.array([0., 0.3, 500., 5.]), jnp.array([300., 3.]))

  @parameterized.named_parameters(
      ('sgd',
@@ -125,6 +127,13 @@ def setUp(self):
      ('lamb',
       alias.lamb(LR),
       optim.LAMB(LR)),
      ('lars',
       alias.lars(
           LR, weight_decay=.5, trust_coefficient=0.003,
           momentum=0.9, eps=1e-3),
       optim.LARS(
           LR, weight_decay=.5, trust_coefficient=0.003,
           beta=0.9, eps=1e-3)),
  )
  def test_flax_optim_equivalence(self, optax_optimizer, flax_optimizer):

8 changes: 6 additions & 2 deletions optax/_src/transform.py
@@ -617,7 +617,9 @@ class ScaleByTrustRatioState(NamedTuple):


def scale_by_trust_ratio(
    min_norm: float = 0.0
    min_norm: float = 0.0,
    trust_coefficient: float = 1.,
    eps: float = 0.,
) -> base.GradientTransformation:
"""Scale updates by trust ratio`.
@@ -626,6 +628,8 @@
  Args:
    min_norm: minimum norm for params and gradient norms; by default is zero.
    trust_coefficient: a multiplier for the trust ratio.
    eps: additive constant added to the denominator for numerical stability.

  Returns:
    An (init_fn, update_fn) tuple.
@@ -643,7 +647,7 @@ def _scale_update(update, param):
      # Clip norms to minimum value, by default no clipping.
      param_norm = numerics.safe_norm(param, min_norm)
      update_norm = numerics.safe_norm(update, min_norm)
      trust_ratio = param_norm / update_norm
      trust_ratio = trust_coefficient * param_norm / (update_norm + eps)

      # If no minimum norm clipping is used
      # Set trust_ratio to 1 in case where parameters would never be updated.
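To make the new arguments concrete, here is an illustrative check (not part of the commit) of the updated formula, trust_ratio = trust_coefficient * param_norm / (update_norm + eps), with numbers chosen so the norms are easy to read off.

import jax.numpy as jnp
import optax

param = jnp.array([3.0, 4.0])    # ||param|| = 5
update = jnp.array([0.6, 0.8])   # ||update|| = 1

tx = optax.scale_by_trust_ratio(trust_coefficient=0.001, eps=0.)
state = tx.init(param)
scaled, _ = tx.update(update, state, param)
# With the previous defaults the ratio would be 5 / 1 = 5; with
# trust_coefficient=0.001 it is 0.001 * 5 / (1 + 0) = 0.005,
# so `scaled` is approximately 0.005 * update.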
